diff --git a/.bazelrc b/.bazelrc index 396b84f70b3..aaf1e10a5c7 100644 --- a/.bazelrc +++ b/.bazelrc @@ -159,6 +159,7 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain # environment variable "TF_MKL_ROOT" every time before build. build:mkl --define=build_with_mkl=true --define=enable_mkl=true build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl --define=build_with_openmp=true build:mkl -c opt # config to build OneDNN backend with a user specified threadpool. @@ -172,6 +173,7 @@ build:mkl_threadpool -c opt build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl_opensource_only --define=build_with_mkl_opensource=true +build:mkl_opensource_only --define=build_with_openmp=true build:mkl_opensource_only -c opt # Config setting to build with oneDNN for Arm. @@ -283,7 +285,7 @@ build:ios --copt=-w build:linux --copt=-w build:linux --host_copt=-w build:macos --copt=-w -build:windows --copt=/w +build:windows --copt=/W0 # Tensorflow uses M_* math constants that only get defined by MSVC headers if # _USE_MATH_DEFINES is defined. @@ -294,9 +296,11 @@ build:windows --host_copt=/D_USE_MATH_DEFINES build:linux --define=PREFIX=/usr build:linux --define=LIBDIR=$(PREFIX)/lib build:linux --define=INCLUDEDIR=$(PREFIX)/include +build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include build:macos --define=PREFIX=/usr build:macos --define=LIBDIR=$(PREFIX)/lib build:macos --define=INCLUDEDIR=$(PREFIX)/include +build:macos --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include # TF_SYSTEM_LIBS do not work on windows. # By default, build TF in C++ 14 mode. diff --git a/README.md b/README.md index 31888cfbbc6..63d85ce2df4 100644 --- a/README.md +++ b/README.md @@ -103,23 +103,22 @@ open-source software development: ### Official Builds -Build Type | Status | Artifacts ------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- -**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA -**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) -**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) -**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) -**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) - +Build Type | Status | Artifacts +----------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA +**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) +**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) +**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) ### Community Supported Builds @@ -151,6 +150,7 @@ Build Type * [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187) * [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190) * [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp) +* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow) * [TensorFlow Chat Room on StackOverflow (not actively monitored by the TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow) * [TensorFlow Blog](https://blog.tensorflow.org) diff --git a/RELEASE.md b/RELEASE.md index 23324d56ca7..d5654424afd 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,35 @@ +# Release 2.5.0 + + + +## Breaking Changes + +* +* + +## Known Caveats + +* +* +* + +## Major Features and Improvements + +* +* + +## Bug Fixes and Other Changes + +* +* +* + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +, , , , , + # Release 2.4.0 @@ -6,6 +38,15 @@ * * +* Certain float32 ops run in lower precsion on Ampere based GPUs, including + matmuls and convolutions, due to the use of + [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/). + Specifically, inputs to such ops are rounded from 23 bits of precision to 10 + bits of precision. This is unlikely to cause issues in practice for deep + learning models. In some cases, TensorFloat-32 is also used for complex64 ops. + TensorFloat-32 can be disabled by running + `config.experimental.enable_tensor_float_32_execution(False)`. The "Major + Features and Improvements" section has more details. * The byte layout for string tensors across the C-API has been updated to match TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s. * C-API functions `TF_StringDecode`, `TF_StringEncode`, and @@ -54,6 +95,42 @@ tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...). * `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please use `tf.data.Dataset.from_tensor_slices` instead. +* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`, + `tf.distribute.StrategyExtended.batch_reduce_to`, + `tf.distribute.ReplicaContext.all_reduce` are renamed to `options`. + `tf.distribute.experimental.CollectiveHints` is renamed + `tf.distribute.experimental.CommunicationOptions`. + `tf.distribute.experimental.CollectiveCommunication` is renamed + `tf.distribute.experimental.CommunicationImplementation`. +* `tf.keras.mixed_precision.experimental`: + * `AutoCastVariable.dtype` now refers to the actual variable dtype, not the + dtype it will be casted to. + * When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a + float16 or bfloat16 tensor instead of a float32 tensor. + * The property + `tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale` is now + a tensor, not a `LossScale` object. This means to get a loss scale of a + `LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead + of `opt.loss_scale()`. + * The property `should_cast_variables` has been removed from + `tf.keras.mixed_precision.experimental.Policy` + * When passing a `tf.mixed_precision.experimental.DynamicLossScale` to + `tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the + `DynamicLossScale`'s multiplier must be 2. + * When passing a `tf.mixed_precision.experimental.DynamicLossScale` to + `tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of + the `DynanmicLossScale` are copied into the `LossScaleOptimizer` instead of + being reused. This means modifying the weights of the `DynamicLossScale` + will no longer affect the weights of the LossScaleOptimizer, and vice versa. + * The global policy can no longer be set to a non-floating point policy in + `tf.keras.mixed_precision.experimental.set_policy` + * In `Layer.call`, `AutoCastVariable`s will no longer be casted within + `MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a + thread local variable is used to determine whether `AutoCastVariable`s are + casted, and those two functions run with a different thread. Note this only + applies if one of these two functions is called within `Layer.call`; if one + of those two functions calls `Layer.call`, `AutoCastVariable`s will still be + casted. ## Known Caveats @@ -65,9 +142,40 @@ * * A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy. * A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models. +* Support for + [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) + on Ampere based GPUs has been added. TensorFloat-32, or TF32 for short, is a + math mode for NVIDIA Ampere GPUs which causes certain float32 ops, such as + matrix multiplications and convolutions, to run much faster on Ampere GPUs but + with reduced precision. This reduced precision has not been found to effect + convergence quality of deep learning models in practice. TensorFloat-32 is + enabled by default, but can be disabled with + `tf.config.experimental.enable_tensor_float_32_execution`. * `tf.distribute`: + * `MultiWorkerMirroredStrategy` is graduated out of experimental. + * Peer failure will no longer cause the cluster to hang. + * Major issues with saving are fixed. + * See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial. * Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental. +* The `tf.keras.mixed_precision` API has been made non-experimental. The major + changes to the new non-experimental API are: + * `tf.keras.mixed_precision.Policy` no longer takes in a + `tf.mixed_precision.experimental.LossScale` in the constructor, and no + longer has a `LossScale` associated with it. Instead, `Model.compile` will + automatically wrap the optimizer with a `LossScaleOptimizer` using dynamic + loss scaling if `Policy.name` is "mixed_float16". + * `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in + different arguments. In particular, it no longer takes in a `LossScale`, and + there is no longer a `LossScale` associated with the `LossScaleOptimizer`. + Instead, `LossScaleOptimizer` directly implements fixed or dynamic loss + scaling. See the documentation of + `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details on + the differences between the experimental `LossScaleOptimizer` and the new + non-experimental `LossScaleOptimizer`. + * `tf.mixed_precision.experimental.LossScale` and its subclasses are + deprecated, as all of its functionality now exists within + `tf.keras.mixed_precision.LossScaleOptimizer` ## Bug Fixes and Other Changes @@ -117,6 +225,10 @@ ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212), [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213), [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214)) + * Fixes a segfault in `tf.quantization.quantize_and_dequantize` + ([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265)) + * Fixes an undefined behavior float cast causing a crash + ([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266)) * TF Core: * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as type annotation for variables representing a Tensor or a value @@ -138,6 +250,8 @@ stateful ops. * Added `tf.config.experimental.get_memory_usage` to return total memory usage of the device. + * Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`. + * Improve shape inference of nested function calls by supporting constant folding across Arg nodes which makes more static values available to shape inference functions. * `tf.data`: * tf.data service: * Added new `tf.data.experimental.service.register_dataset` and @@ -182,7 +296,16 @@ how many times the function is called, and independent of global seed settings. * `tf.distribute`: - * + * (Experimental) Parameter server training: + * Replaced the existing + `tf.distribute.experimental.ParameterServerStrategy` symbol with + a new class that is for parameter server training in TF2. Usage with + the old symbol, usually with Estimator, should be replaced with + `tf.compat.v1.distribute.experimental.ParameterServerStrategy`. + * Added `tf.distribute.experimental.coordinator.*` namespace, + including the main API `ClusterCoordinator` for coordinating the + training cluster, the related data structure `RemoteValue` + and `PerWorkerValue`. * `tf.keras`: * Improvements from the functional API refactoring: * Functional model construction does not need to maintain a global @@ -217,6 +340,8 @@ * Improvements to Keras preprocessing layers: * TextVectorization can now accept a vocabulary list or file as an init arg. + * TextVectorization, StringLookup, and IntegerLookup can now accept a + vocabulary file via the `set_vocab_from_file` method. * Normalization can now accept mean and variance values as init args. * In `Attention` and `AdditiveAttention` layers, the `call()` method now accepts a `return_attention_scores` argument. When set to @@ -224,6 +349,15 @@ argument. * Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints with the same implementation as their `tf.losses` equivalent. + * For Keras model, the individual call of `Model.evaluate` uses no cached + data for evaluation, while `Model.fit` uses cached data when + `validation_data` arg is provided for better performance. + * Added a `save_traces` argument to `model.save`/ + `tf.keras.models.save_model` which determines whether the SavedModel + format stores the Keras model/layer call functions. The traced functions + allow Keras to revive custom models and layers without the original + class definition, but if this isn't required the tracing can be + disabled with the added option. * `tf.function` / AutoGraph: * Added `experimental_follow_type_hints` argument for `tf.function`. When True, the function may use type annotations to optimize the tracing @@ -269,6 +403,7 @@ `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`. * `DynamicBuffer::AddJoinedString()` will now add a separator if the first string to be joined is empty. + * Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion. * * `tf.random`: @@ -277,7 +412,7 @@ * Math and Linear Algebra: - * + * Add `tf.math.erfcinv`, the inverse to `tf.math.erfc`. * TPU Enhancements: @@ -323,6 +458,12 @@ didn't have the keys sorted, the keys and values were not being printed in accordance with their correct mapping. +* `TensorRT` + + * We now issue a warning when the `session_config` parameter for the TF1 + converter is used or the `rewrite_config_template` field in the TF2 + converter parameter object is used. + * Other: * We have replaced uses of "whitelist" and "blacklist" with "allowlist" @@ -331,6 +472,8 @@ context. * Add `tf.config.experimental.mlir_bridge_rollout` which will help us rollout the new MLIR TPU bridge. + * Added `tf.experimental.register_filesystem_plugin` to load modular + filesystem plugins from Python * ## Thanks to our Contributors @@ -703,6 +846,7 @@ stjohnso98, , , , , * Add `tf.saved_model.LoadOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/LoadOptions?hl=en) as arg with default value `None` to choose the I/O device for loading models and weights. * Update `tf.saved_model.SaveOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/SaveOptions?hl=en) as arg with default value `None` to choose the I/O device for saving models and weights. * Mutable tables now restore checkpointed values when loaded from SavedModel. + * The user object metadata field in the SavedModel proto has been deprecated as part of the updates to Keras SavedModel. Keras was the only consumer of this field prior to the update. * GPU * TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities. * Remove environmental variable `TF_USE_CUDNN`. @@ -731,6 +875,7 @@ stjohnso98, , , , , * Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses. * Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`. * Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization. + * Add `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods to gather and concatenate `tf.distribute.DistributedValues` across workers and devices. ### `tf.keras`: * Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs. diff --git a/configure.py b/configure.py index e381c8c20db..b4907775d93 100644 --- a/configure.py +++ b/configure.py @@ -1163,12 +1163,9 @@ def set_system_libs_flag(environ_cp): syslibs = ','.join(sorted(syslibs.split())) write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) - if 'PREFIX' in environ_cp: - write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX']) - if 'LIBDIR' in environ_cp: - write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR']) - if 'INCLUDEDIR' in environ_cp: - write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR']) + for varname in ('PREFIX', 'LIBDIR', 'INCLUDEDIR', 'PROTOBUF_INCLUDE_PATH'): + if varname in environ_cp: + write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname])) def is_reduced_optimize_huge_functions_available(environ_cp): diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 420c7479e50..274a829f575 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -3,6 +3,7 @@ # learning applications. load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary") load( "//tensorflow/core/platform:build_config.bzl", @@ -238,6 +239,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "linux_mips64", + values = {"cpu": "mips64"}, + visibility = ["//visibility:public"], +) + config_setting( name = "debug", values = { @@ -563,18 +570,45 @@ selects.config_setting_group( ], ) +# 'enable_registration_v2' opts-in to a different implementation of op and +# kernel registration - REGISTER_OP, REGISTER_KERNEL_BUILDER, etc. +# +# This setting is currently experimental. The 'v2' implementation does _not_ +# correspond to a particular, finalized design; rather, it relates to +# developing one. +# +# The current aim of the 'v2' implementation is to allow 'unused' ops and +# kernels to be discarded by the linker (to the benefit of binary size). +bool_flag( + name = "enable_registration_v2", + build_setting_default = False, + visibility = ["//visibility:public"], +) + +config_setting( + name = "registration_v1", + flag_values = {":enable_registration_v2": "False"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "registration_v2", + flag_values = {":enable_registration_v2": "True"}, + visibility = ["//visibility:public"], +) + # DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST! # Instead, please use public APIs or public build rules TF provides. # If you need functionality that is not exposed, we will work with you to expand our public APIs. package_group( name = "internal", - packages = ["//tensorflow/..."], + packages = [ + "//learning/lib/ami/simple_ml/...", + "//tensorflow/...", + ], ) -package_group( - name = "ndarray_tensor_allow_list", - packages = ["//learning/pathways/..."], -) +package_group(name = "ndarray_tensor_allow_list") # Packages that use private types symbols, until they are exported. # TODO(b/154650521) Remove. @@ -606,6 +640,7 @@ bzl_library( "//third_party/mkl:build_defs_bzl", "//third_party/mkl_dnn:build_defs_bzl", "//third_party/ngraph:build_defs_bzl", + "@bazel_skylib//rules:common_settings", "@local_config_cuda//cuda:build_defs_bzl", "@local_config_rocm//rocm:build_defs_bzl", "@local_config_tensorrt//:build_defs_bzl", @@ -706,6 +741,9 @@ tf_cc_shared_object( visibility = ["//visibility:public"], deps = [ "//tensorflow/c/experimental/filesystem:filesystem_interface", + "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", + "//tensorflow/c:kernels_hdrs", + "//tensorflow/c:ops_hdrs", "//tensorflow/cc/saved_model:loader_lite_impl", "//tensorflow/core/common_runtime:core_cpu_impl", "//tensorflow/core:framework_internal_impl", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 693bae5cfc7..ce3f6c9f877 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -202,6 +202,7 @@ tf_cuda_library( ":tf_status", ":tf_tensor", "@com_google_absl//absl/strings", + "//tensorflow/c/experimental/filesystem:modular_filesystem", "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", @@ -511,6 +512,18 @@ tf_cuda_library( ], ) +cc_library( + name = "kernels_hdrs", + hdrs = ["kernels.h"], + visibility = ["//tensorflow:internal"], + deps = [ + ":c_api_internal", + ":tf_datatype", + ":tf_status", + ":tf_tensor", + ], +) + tf_cuda_library( name = "kernels", srcs = [ @@ -565,6 +578,16 @@ tf_cuda_library( alwayslink = 1, ) +cc_library( + name = "ops_hdrs", + hdrs = ["ops.h"], + visibility = ["//tensorflow:internal"], + deps = [ + ":tf_datatype", + ":tf_status", + ], +) + # ----------------------------------------------------------------------------- # Tests diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index a03e9227a75..9579efab94d 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/platform.h" // NOLINT #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +#include "tensorflow/c/experimental/filesystem/modular_filesystem.h" #include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope_internal.h" @@ -2606,4 +2607,14 @@ void TF_RegisterLogListener(void (*listener)(const char*)) { #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) } +void TF_RegisterFilesystemPlugin(const char* plugin_filename, + TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "FileSystem plugin functionality is not supported on mobile"); +#else + status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index db5f8fd68f8..f550b690e27 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1577,6 +1577,13 @@ TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); TF_CAPI_EXPORT extern void TF_RegisterLogListener( void (*listener)(const char*)); +// Register a FileSystem plugin from filename `plugin_filename`. +// +// On success, place OK in status. +// On failure, place an error status in status. +TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin( + const char* plugin_filename, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 2f3dba8a9ce..0cc29af7117 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -563,15 +563,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, collective_executor_handle->get()->StartAbort(status->status); } -TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx, - const char* task, - TF_Status* status) { +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth( + TFE_Context* ctx, const char* task, int64_t timeout_in_ms, + TF_Status* status) { tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); auto collective_executor_handle = context->GetCollectiveExecutorHandle(); tensorflow::Notification done; collective_executor_handle->get()->remote_access()->CheckPeerHealth( - task, [&done, status](const Status& s) { + task, timeout_in_ms, [&done, status](const Status& s) { status->status = s; done.Notify(); }); diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 317e2bf4a91..832d2db6ab2 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, // Checks the health of collective ops peers. Explicit health check is needed in // multi worker collective ops to detect failures in the cluster. If a peer is // down, collective ops may hang. -TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx, - const char* task, - TF_Status* status); +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth( + TFE_Context* ctx, const char* task, int64_t timeout_in_ms, + TF_Status* status); // Information about the shape of a Tensor and its type. struct TF_ShapeAndType { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index b90b2644269..c44d0ee6873 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -10,6 +10,9 @@ load( "tf_cuda_library", ) +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") + # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "filegroup") @@ -94,6 +97,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime:worker_interface", "//tensorflow/core:gpu_runtime", ] + internal_tfrt_deps(), alwayslink = 1, @@ -638,6 +642,19 @@ cc_library( ], ) +cc_header_only_library( + name = "tfe_tensorhandle_internal_hdrs_only", + extra_deps = [ + "@com_google_absl//absl/strings", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":tfe_tensorhandle_internal", + ], +) + tf_cuda_library( name = "c_api_test_util", testonly = 1, diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 5f388bfe0cd..3418bccf050 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -70,6 +70,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" #endif // !IS_MOBILE_PLATFORM #include "tensorflow/core/framework/node_def_util.h" @@ -855,41 +856,42 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, #else // !defined(IS_MOBILE_PLATFORM) tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - // TODO(yuefengz): support partially specified `worker_name`. - tensorflow::core::RefCountPtr eager_client; - status->status = context->GetClient(worker_name, &eager_client); - if (!status->status.ok()) { + tensorflow::GrpcServer* grpc_server = + dynamic_cast(context->GetServer()); + if (grpc_server == nullptr) { + status->status = + tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer."); + return false; + } + tensorflow::WorkerInterface* wi = + grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name); + if (wi == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Unable to find worker interface corresponding to task ", worker_name); return false; } - // Send a rpc request to the worker to check aliveness. - tensorflow::eager::KeepAliveRequest request; - request.set_context_id(context->GetContextId()); - tensorflow::eager::KeepAliveResponse response; - - tensorflow::Status keep_alive_status; + tensorflow::GetStatusRequest request; + tensorflow::GetStatusResponse response; + tensorflow::Status remote_status; tensorflow::Notification done; - eager_client->KeepAliveAsync( - &request, &response, - [&keep_alive_status, &done](const tensorflow::Status& s) { - keep_alive_status = s; - done.Notify(); - }); + wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true, + [&remote_status, &done](const tensorflow::Status& s) { + remote_status = s; + done.Notify(); + }); done.WaitForNotification(); + // We set OK status so the call does not raise any exceptions. Instead, caller + // users the return value to tell if the remote worker is alive. status->status = tensorflow::Status::OK(); - // If `context_id` doesn't exist on the remote worker, an InvalidArgument - // error will return. But this still indicates that the remote worker is - // alive. - if (keep_alive_status.ok() || - keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) { + if (remote_status.ok()) { return true; - } else { - LOG(INFO) << "Remote worker " << worker_name - << " is not alive: " << keep_alive_status.error_message(); - return false; } + LOG(INFO) << "Remote worker " << worker_name + << " is not alive: " << remote_status.error_message(); + return false; #endif // !IS_MOBILE_PLATFORM } diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index cc2270755bf..1ef536a66f6 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -638,3 +638,19 @@ void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status) { tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable); } + +const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr) { + status->status = tensorflow::errors::InvalidArgument("Invalid handle"); + return nullptr; + } + return tensorflow::unwrap(h)->DeviceType(&status->status); +} + +int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr) { + status->status = tensorflow::errors::InvalidArgument("Invalid handle"); + return -1; + } + return tensorflow::unwrap(h)->DeviceId(&status->status); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 12546c6082a..d0739a5437d 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -553,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status); +// Returns the device type of the operation that produced `h`. +TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType( + TFE_TensorHandle* h, TF_Status* status); + +// Returns the device ID of the operation that produced `h`. +TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index 4975d303375..4fe83b5116d 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -411,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) { TF_DeleteStatus(status); } +TEST(CAPI, TensorHandleNullptr) { + TFE_TensorHandle* h = nullptr; + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + const char* device_type = TFE_TensorHandleDeviceType(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_type, nullptr); + ASSERT_EQ("Invalid handle", string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int device_id = TFE_TensorHandleDeviceID(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_id, -1); + ASSERT_EQ("Invalid handle", string(TF_Message(status.get()))); +} + +TEST(CAPI, TensorHandleDevices) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx); + const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type; + int device_id = TFE_TensorHandleDeviceID(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id) << device_id; + + // Disable the test if no GPU is present. + string gpu_device_name; + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* shape_op = ShapeOp(ctx, hgpu); + TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + device_type = TFE_TensorHandleDeviceType(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type; + + device_id = TFE_TensorHandleDeviceID(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id) << device_id; + + TFE_DeleteOp(shape_op); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TFE_DeleteTensorHandle(hcpu); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteExecutor(executor); + TFE_DeleteContext(ctx); +} + +TEST(CAPI, TensorHandleDefaults) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx); + const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type; + int device_id = TFE_TensorHandleDeviceID(h_default, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id) << device_id; + + TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice( + h_default, ctx, "/device:CPU:0", status.get()); + const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu; + int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id_cpu) << device_id_cpu; + + TFE_DeleteTensorHandle(h_default); + TFE_DeleteTensorHandle(h_cpu); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteExecutor(executor); + TFE_DeleteContext(ctx); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc index 7a438085fb5..393ad2ceb98 100644 --- a/tensorflow/c/eager/gradient_checker_test.cc +++ b/tensorflow/c/eager/gradient_checker_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) { } TEST_P(GradientCheckerTest, TestGradCheckMatMul) { + // Computing numerical gradients with TensorFloat-32 is numerically unstable + enable_tensor_float_32_execution(false); + std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); AbstractContextPtr ctx; diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index cd4febba8c1..84ba0e061cc 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -62,10 +62,11 @@ Status RegisterGradients(GradientRegistry* registry) { TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer)); return Status::OK(); } - // Computes // y = inputs[0] + inputs[1] // return grad(y, {inputs[0], inputs[1]}) @@ -74,11 +75,11 @@ Status AddGradModel(AbstractContext* ctx, absl::Span outputs, const GradientRegistry& registry) { TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); + auto tape = std::make_unique(/*persistent=*/false); tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[1])); // Watch y. std::vector add_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add")); // Compute x+y. @@ -97,7 +98,6 @@ Status AddGradModel(AbstractContext* ctx, } outputs[0] = out_grads[0]; outputs[1] = out_grads[1]; - delete tape; return Status::OK(); } @@ -109,10 +109,10 @@ Status ExpGradModel(AbstractContext* ctx, absl::Span outputs, const GradientRegistry& registry) { TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); + auto tape = std::make_unique(/*persistent=*/false); tape->Watch(ToId(inputs[0])); // Watch x. std::vector exp_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); TF_RETURN_IF_ERROR( ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp")); std::unordered_map @@ -128,7 +128,6 @@ Status ExpGradModel(AbstractContext* ctx, exp_output->Unref(); } outputs[0] = out_grads[0]; - delete tape; return Status::OK(); } @@ -140,10 +139,10 @@ Status SqrtGradModel(AbstractContext* ctx, absl::Span outputs, const GradientRegistry& registry) { TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); + auto tape = std::make_unique(/*persistent=*/false); tape->Watch(ToId(inputs[0])); // Watch x. std::vector sqrt_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); TF_RETURN_IF_ERROR( ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt")); std::unordered_map @@ -159,7 +158,6 @@ Status SqrtGradModel(AbstractContext* ctx, sqrt_output->Unref(); } outputs[0] = out_grads[0]; - delete tape; return Status::OK(); } @@ -172,12 +170,12 @@ Status IdentityNGradModel(AbstractContext* ctx, absl::Span outputs, const GradientRegistry& registry) { TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); + auto tape = std::make_unique(/*persistent=*/false); tape->Watch(ToId(inputs[0])); tape->Watch(ToId(inputs[1])); vector identity_n_outputs(2); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); TF_RETURN_IF_ERROR(ops::IdentityN( tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN")); @@ -195,7 +193,71 @@ Status IdentityNGradModel(AbstractContext* ctx, } outputs[0] = out_grads[0]; outputs[1] = out_grads[1]; - delete tape; + return Status::OK(); +} + +// Computes +// y = - inputs[0] +// return grad(y, {inputs[0]}) +Status NegGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = std::make_unique(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); + + std::vector neg_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); + TF_RETURN_IF_ERROR( + ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg")); + + std::unordered_map + source_tensors_that_are_targets; + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto neg_output : neg_outputs) { + neg_output->Unref(); + } + outputs[0] = out_grads[0]; + return Status::OK(); +} + +// Computes +// y = inputs[0] - inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status SubGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = std::make_unique(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + std::vector sub_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); + TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs, + absl::MakeSpan(sub_outputs), + "Sub")); // Compute x-y. + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto sub_output : sub_outputs) { + sub_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; return Status::OK(); } @@ -536,6 +598,111 @@ TEST_P(CppGradients, TestIdentityNGrad) { result_tensor = nullptr; } +TEST_P(CppGradients, TestNegGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // y = - x + // outputs = tape.gradient(y, x) + std::vector outputs(1); + s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, -1.0); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; +} + +TEST_P(CppGradients, TestSubGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + y.reset(y_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // tape.watch(y) + // y = x - y + // outputs = tape.gradient(y, [x, y]) + std::vector outputs(2); + s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 1.0); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; + + s = getValue(outputs[1], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, -1.0); + outputs[1]->Unref(); + TF_DeleteTensor(result_tensor); +} + TEST_P(CppGradients, TestSetAttrString) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -575,7 +742,7 @@ TEST_P(CppGradients, TestSetAttrString) { int num_retvals = 1; std::vector outputs(1); GradientRegistry registry; - std::unique_ptr tape(new Tape(/*persistent=*/false)); + auto tape = std::make_unique(/*persistent=*/false); s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs), &num_retvals, &forward_op, tape.get(), registry); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h index 6d32d482747..bb6d471f12f 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle { virtual const char* DeviceName(Status* status) const = 0; // Returns the device where the tensor was placed. virtual const char* BackingDeviceName(Status* status) const = 0; + // Returns the device type which created the handle. + virtual const char* DeviceType(Status* status) const = 0; + // Returns the device ID which created the handle. + virtual int DeviceId(Status* status) const = 0; // Returns a tensor for the handle. If tensor is remote, it will be copied. virtual AbstractTensorInterface* Resolve(Status* status) = 0; diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc index 4114f50a798..16cb01110fd 100644 --- a/tensorflow/c/eager/mnist_gradients_test.cc +++ b/tensorflow/c/eager/mnist_gradients_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -43,6 +44,11 @@ class CppGradients TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = StatusFromTF_Status(status.get()); CHECK_EQ(errors::OK, s.code()) << s.error_message(); + + // Computing numerical gradients with TensorFloat-32 is numerically + // unstable. Some forward pass tests also fail with TensorFloat-32 due to + // low tolerances + enable_tensor_float_32_execution(false); } }; diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index e270bfcbb80..095f33ff303 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -58,7 +58,7 @@ using ExecutorPtr = std::unique_ptr; class DeviceThread { public: // Starts a background thread waiting for `StartExecute`. - explicit DeviceThread(const std::string& device) + explicit DeviceThread(const std::string& device, const bool is_async) : status_(TF_NewStatus()), device_(device), // If the context's default exector is set to async, re-using that in @@ -67,7 +67,7 @@ class DeviceThread { // // TODO(allenl): We should have an async API that works with the // parallel device. - executor_(TFE_NewExecutor(/*is_async=*/false)), + executor_(TFE_NewExecutor(is_async)), op_(nullptr), thread_(tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "parallel_device_execute", @@ -236,12 +236,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name, } } -ParallelDevice::ParallelDevice(const std::vector& devices) +ParallelDevice::ParallelDevice(const std::vector& devices, + const bool is_async) : underlying_devices_(devices) { device_threads_.reserve(devices.size()); for (int device_index = 0; device_index < devices.size(); ++device_index) { device_threads_.emplace_back( - new DeviceThread(devices[device_index].c_str())); + new DeviceThread(devices[device_index].c_str(), is_async)); } } diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h index b3dc47ab088..1bb9ce0f663 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -49,7 +49,10 @@ class DeviceThread; // placed on each underlying device. class ParallelDevice { public: - explicit ParallelDevice(const std::vector& devices); + // Eager async execution is only supported when remote eager is not in use + // (b/157523095). + explicit ParallelDevice(const std::vector& devices, + const bool is_async = false); ~ParallelDevice(); diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc index 5ff28e4229a..3e96b054949 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc @@ -182,9 +182,8 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file, ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); std::string cacheKey(scheme); - hdfsBuilder* builder = libhdfs->hdfsNewBuilder(); if (scheme == "file") { - libhdfs->hdfsBuilderSetNameNode(builder, nullptr); + namenode = ""; } else if (scheme == "viewfs") { char* defaultFS = nullptr; libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS); @@ -200,21 +199,24 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file, // The default NameNode configuration will be used (from the XML // configuration files). See: // https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259 - libhdfs->hdfsBuilderSetNameNode(builder, "default"); + namenode = "default"; } else if (scheme == "har") { std::string path_har = path; SplitArchiveNameAndPath(&path_har, &namenode, status); if (TF_GetCode(status) != TF_OK) return nullptr; - libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str()); - cacheKey += namenode; } else { - libhdfs->hdfsBuilderSetNameNode( - builder, namenode.empty() ? "default" : namenode.c_str()); - cacheKey += namenode; + if (namenode.empty()) { + namenode = "default"; + } } + cacheKey += namenode; + absl::MutexLock l(&hadoop_file->connection_cache_lock); if (hadoop_file->connection_cache.find(cacheKey) == hadoop_file->connection_cache.end()) { + hdfsBuilder* builder = libhdfs->hdfsNewBuilder(); + libhdfs->hdfsBuilderSetNameNode( + builder, namenode.empty() ? nullptr : namenode.c_str()); auto cacheFs = libhdfs->hdfsBuilderConnect(builder); if (cacheFs == nullptr) { TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno)); diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc index 5cba7b28fda..5551642127d 100644 --- a/tensorflow/c/experimental/gradients/math_grad.cc +++ b/tensorflow/c/experimental/gradients/math_grad.cc @@ -24,6 +24,7 @@ using std::vector; using tensorflow::ops::Conj; using tensorflow::ops::MatMul; using tensorflow::ops::Mul; +using tensorflow::ops::Neg; using tensorflow::ops::SqrtGrad; namespace tensorflow { @@ -201,6 +202,56 @@ class MatMulGradientFunction : public GradientFunction { AttrBuilder forward_attrs; }; +class NegGradientFunction : public GradientFunction { + public: + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a Neg op Y = -X, the gradients are: + * + * dX = -U + * + */ + + grad_outputs->resize(1); + std::string name = "Neg_Grad"; + TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]}, + absl::MakeSpan(*grad_outputs), name.c_str())); + return Status::OK(); + } + ~NegGradientFunction() override {} +}; + +class SubGradientFunction : public GradientFunction { + public: + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a Sub op A-B, the gradients are: + * + * dA = U + * dB = -U + * + */ + + grad_outputs->resize(2); + + // Grad for A + DCHECK(grad_inputs[0]); + (*grad_outputs)[0] = grad_inputs[0]; + (*grad_outputs)[0]->Ref(); + + // Grad for B + // negate the upstream grad + std::vector neg_outputs(1); + std::string name = "Neg_Sub_Grad_B"; + TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]}, + absl::MakeSpan(neg_outputs), name.c_str())); + (*grad_outputs)[1] = neg_outputs[0]; + + return Status::OK(); + } + ~SubGradientFunction() override {} +}; + } // namespace BackwardFunction* AddRegisterer(const ForwardOperation& op) { @@ -239,5 +290,23 @@ BackwardFunction* SqrtRegisterer(const ForwardOperation& op) { return new BackwardFunction(gradient_function, default_gradients); } +BackwardFunction* NegRegisterer(const ForwardOperation& op) { + auto gradient_function = new NegGradientFunction; + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +BackwardFunction* SubRegisterer(const ForwardOperation& op) { + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto gradient_function = new SubGradientFunction; + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/math_grad.h b/tensorflow/c/experimental/gradients/math_grad.h index 7faeadcca81..756c5f84153 100644 --- a/tensorflow/c/experimental/gradients/math_grad.h +++ b/tensorflow/c/experimental/gradients/math_grad.h @@ -24,6 +24,8 @@ BackwardFunction* AddRegisterer(const ForwardOperation& op); BackwardFunction* ExpRegisterer(const ForwardOperation& op); BackwardFunction* MatMulRegisterer(const ForwardOperation& op); BackwardFunction* SqrtRegisterer(const ForwardOperation& op); +BackwardFunction* NegRegisterer(const ForwardOperation& op); +BackwardFunction* SubRegisterer(const ForwardOperation& op); } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 94418aa9ad9..3748a6ae9ec 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -11,11 +11,21 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "stream_executor_hdrs", + hdrs = ["stream_executor.h"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status", + ], +) + cc_library( name = "stream_executor", srcs = ["stream_executor.cc"], hdrs = ["stream_executor.h"], - visibility = ["//visibility:public"], + visibility = ["//tensorflow:internal"], deps = [ ":stream_executor_internal", "//tensorflow/c:c_api_macros", diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 70d080a682f..dcd652d9fdf 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -404,10 +404,12 @@ Status RestoreSession(const RunOptions& run_options, const uint64 read_start_microseconds = Env::Default()->NowMicros(); std::vector asset_file_defs; TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs)); - TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir, - meta_graph.saver_def().restore_op_name(), - meta_graph.saver_def().filename_tensor_name(), - asset_file_defs, session->get())); + if (meta_graph.has_saver_def()) { + TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir, + meta_graph.saver_def().restore_op_name(), + meta_graph.saver_def().filename_tensor_name(), + asset_file_defs, session->get())); + } // Record walltime spent in restoring graph from disk, but postpone metric // increments until graph init finishes. const uint64 restore_graph_walltime = diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 7d87b5f0715..7bd7b5964f9 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -7,6 +7,9 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_ load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts") load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") + # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "filegroup") @@ -283,6 +286,7 @@ cc_library( "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", ], @@ -291,7 +295,7 @@ cc_library( # Header-only version of "flags" library, for linking from the shared object # without ODR violations. cc_library( - name = "flags_headers_only", + name = "flags_headers", hdrs = ["flags.h"], visibility = [":friends"], deps = [ @@ -302,6 +306,11 @@ cc_library( ], ) +cc_header_only_library( + name = "flags_headers_only", + deps = [":flags_headers"], +) + cc_library( name = "common", srcs = [ @@ -447,8 +456,8 @@ cc_library( # Header-only version of "flags" library, for linking from the shared object # without ODR violations. cc_library( - name = "get_compiler_ir_hdrs_only", - hdrs = ["get_compiler_ir.h"], + name = "get_compiler_ir_hdrs", + textual_hdrs = ["get_compiler_ir.h"], visibility = [ ":internal", "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", @@ -463,6 +472,23 @@ cc_library( ], ) +cc_header_only_library( + name = "get_compiler_ir_hdrs_only", + deps = [":get_compiler_ir_hdrs"], +) + +# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. +cc_header_only_library( + name = "xla_jit_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_gpu_device", + ":xla_gpu_jit", + ], +) + cc_library( name = "xla_kernel_creator", srcs = [ @@ -842,9 +868,12 @@ tf_cc_test( "partially_decluster_pass_test.cc", "rearrange_function_argument_pass_test.cc", ], - # TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value - # error. - tags = ["nomsan"] + tf_cuda_tests_tags(), + tags = [ + # TODO(b/141643254) Re-enable msan after fixing + # use-of-uninitialized-value error. + "nomsan", + "no_cuda_asan", # TODO(b/171317460): re-enable. + ] + tf_cuda_tests_tags(), deps = [ ":common", ":compilability_check_util", @@ -1075,15 +1104,3 @@ cc_library( ], alwayslink = 1, ) - -# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. -cc_header_only_library( - name = "xla_jit_headers_lib", - visibility = ["//visibility:public"], - deps = [ - ":xla_cpu_device", - ":xla_cpu_jit", - ":xla_gpu_device", - ":xla_gpu_jit", - ], -) diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index ee7daf092da..52d8fb94ff6 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -167,8 +167,16 @@ void AllocateAndParseFlags() { jitter_flags = new IntroduceFloatingPointJitterPassFlags; jitter_flags->jitter_amount = 1e-5; - mlir_flags = new MlirCommonFlags; - mlir_flags->tf_mlir_enable_mlir_bridge = false; + // The `enable_mlir_bridge` flag allows the user to explicitly request that + // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge. + // + // The `enable_mlir_bridge_is_explicit` variable tracks whether or not the + // user has made an explicit request. That is, if this variable is set to + // true, the program honors the user's request as per `enable_mlir_bridge`; if + // it's set to false, the default behavior is used (which may run either + // bridge, on a per-graph basis). + bool enable_mlir_bridge = false; + bool enable_mlir_bridge_is_explicit = false; auto setter_for_jitter_tensor_names = [](string sequence) { jitter_flags->tensor_names = absl::StrSplit(sequence, ','); @@ -217,12 +225,24 @@ void AllocateAndParseFlags() { "The amount of jitter to introduce. This amount is added to each " "element in the tensors named in `tensor_names."), - Flag("tf_mlir_enable_mlir_bridge", - &mlir_flags->tf_mlir_enable_mlir_bridge, - "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")}); + Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge, + "Enables experimental MLIR-Based TensorFlow Compiler Bridge.", + &enable_mlir_bridge_is_explicit)}); AppendMarkForCompilationPassFlagsInternal(flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); + + mlir_flags = new MlirCommonFlags; + if (!enable_mlir_bridge_is_explicit) { + mlir_flags->tf_mlir_enable_mlir_bridge = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; + } else if (enable_mlir_bridge) { + mlir_flags->tf_mlir_enable_mlir_bridge = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + } else { + mlir_flags->tf_mlir_enable_mlir_bridge = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; + } } } // namespace diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 5612b3b5864..a0860da7b04 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { @@ -135,7 +136,7 @@ struct IntroduceFloatingPointJitterPassFlags { // Flags for common MLIR configurations. struct MlirCommonFlags { - bool tf_mlir_enable_mlir_bridge; + ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge; }; // Return a pointer to the DumpGraphFlags struct; diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 12b40b1c83b..0f0f43cbad6 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -274,18 +274,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); - xla::ThenExecuteFunction then_execute; - if (ctx->op_device_context()) { - then_execute = [&](se::Stream* stream, std::function fn) { - Status status = ctx->op_device_context()->ThenExecute( - down_cast(ctx->device()), stream, std::move(fn)); - if (!status.ok()) { - // This should never happen. - LOG(ERROR) << "ThenExecute failed " << status; - } - }; - run_options.set_then_execute_function(&then_execute); - } Env* env = Env::Default(); auto start_time = env->NowMicros(); @@ -522,18 +510,6 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); - xla::ThenExecuteFunction then_execute; - if (ctx->op_device_context()) { - then_execute = [&](se::Stream* stream, std::function fn) { - Status status = ctx->op_device_context()->ThenExecute( - down_cast(ctx->device()), stream, std::move(fn)); - if (!status.ok()) { - // This should never happen. - LOG(ERROR) << "ThenExecute failed " << status; - } - }; - run_options.set_then_execute_function(&then_execute); - } Env* env = Env::Default(); auto start_time = env->NowMicros(); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index d7d5ee02265..461a6692c84 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -283,25 +283,29 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); + // TODO(b/155596779): Support TensorList args. bool has_tensor_list_arg = absl::c_any_of(args, [](const XlaCompiler::Argument arg) { return arg.kind == XlaCompiler::Argument::kTensorList; }); const ConfigProto* config = ctx->function_library()->config_proto(); - bool use_mlir = config && config->experimental().enable_mlir_bridge(); + // TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR. + bool use_mlir = config && config->experimental().enable_mlir_bridge() && + !has_tensor_list_arg && + node_def.op() != "VarIsInitializedOp"; #ifdef LIBTPU_ON_GCE - if (use_mlir && has_tensor_list_arg) { + if (use_mlir) { LOG(WARNING) << "MLIR is not supported in this environment."; } return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); #else - // TODO(b/155596779): Support TensorList args. - if (!use_mlir || !has_tensor_list_arg) { + if (!use_mlir) { return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); } + VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; std::vector control_rets; if (result_dtypes.empty()) { diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index d4a69da4898..b90f8b7b990 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -89,7 +89,8 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, XlaOpRegistry::RegisterCompilationKernels(); // Only check for compilability if the MLIR bridge is not enabled. - if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge != + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { std::vector diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 30e8d8a86a7..1636bbb89ee 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -45,6 +45,7 @@ filegroup( "include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td", "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", @@ -122,8 +123,6 @@ gentbl( tbl_outs = [ ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"), ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"), - ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc"), - ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", @@ -150,6 +149,24 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + td_relative_includes = [ + "include", + ], + td_srcs = [":hlo_ops_td_files"], +) + +gentbl( + name = "hlo_ops_base_structs_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"), + ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + td_relative_includes = [ + "include", + ], td_srcs = [":hlo_ops_td_files"], ) @@ -194,6 +211,63 @@ gentbl( td_srcs = [":hlo_ops_td_files"], ) +gentbl( + name = "lhlo_gpu_ops_structs_inc_gen", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", + tbl_outs = [ + ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"), + ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td", + td_relative_includes = [ + "include", + ], + td_srcs = [ + ":hlo_ops_td_files", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td", + ], +) + +cc_library( + name = "lhlo_gpu_ops_structs", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc", + "lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h", + ], + includes = ["include"], + deps = [ + ":lhlo_gpu_ops_structs_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +gentbl( + name = "lhlo_gpu_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", + tbl_outs = [ + ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"), + ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td", + td_relative_includes = [ + "include", + ], + td_srcs = [ + ":hlo_ops_td_files", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td", + ], +) + #TODO(aminim): revisit the naming and grouping of these rules post-move. gentbl( name = "canonicalize_inc_gen", @@ -251,6 +325,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "hlo_ops_base_structs", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc", + "lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h", + ], + includes = ["include"], + deps = [ + ":hlo_ops_base_structs_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "convert_op_folder", srcs = ["lib/utils/convert_op_folder.cc"], @@ -284,6 +375,7 @@ cc_library( ":chlo_ops_inc_gen", ":convert_op_folder", ":hlo_ops_base_inc_gen", + ":hlo_ops_base_structs", ":hlo_ops_inc_gen", ":infer_fusibility_op_interface", "@llvm-project//llvm:Support", @@ -314,6 +406,7 @@ cc_library( includes = ["include"], deps = [ ":hlo_ops_base_inc_gen", + ":hlo_ops_base_structs", ":lhlo_ops_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -330,6 +423,39 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "lhlo_gpu", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc", + "lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h", + ], + includes = ["include"], + deps = [ + ":hlo", + ":hlo_ops_base_structs", + ":infer_fusibility_op_interface", + ":lhlo_gpu_ops_inc_gen", + ":lhlo_gpu_ops_structs", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:CopyOpInterface", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:ViewLikeInterface", + ], + alwayslink = 1, +) + cc_library( name = "hlo_dialect_registration", srcs = ["lib/Dialect/mhlo/IR/init.cc"], @@ -337,6 +463,7 @@ cc_library( deps = [ ":hlo", ":lhlo", + ":lhlo_gpu", "@llvm-project//mlir:IR", ], ) @@ -385,6 +512,7 @@ cc_library( ":lhlo", ":map_hlo_to_lhlo_op", "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", "@llvm-project//mlir:StandardOps", ], ) @@ -522,6 +650,7 @@ cc_library( "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:ViewLikeInterface", ], alwayslink = 1, ) @@ -878,6 +1007,7 @@ cc_binary( ":all_passes", ":hlo", ":lhlo", + ":lhlo_gpu", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt index 09bdca84cd3..3fa2b908d9c 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt @@ -25,7 +25,22 @@ function(add_mlir_hlo_dialect dialect dialect_namespace) endfunction() add_mlir_hlo_dialect(chlo_ops chlo) -add_mlir_hlo_dialect(hlo_ops mhlo) add_mlir_hlo_dialect(lhlo_ops lmhlo) +set(LLVM_TARGET_DEFINITIONS hlo_ops.td) +mlir_tablegen(hlo_ops.h.inc -gen-op-decls) +mlir_tablegen(hlo_ops.cc.inc -gen-op-defs) +mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls) +mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRhlo_opsIncGen) + +set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td) +mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls) +mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs) +set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td) +mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls) +mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen) +add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen) + add_mlir_interface(infer_fusibility_op_interface) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 60ee4e613eb..b354189c12a 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" @@ -32,7 +33,7 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // clang-format off -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" // clang-format on diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 507f7c11d63..579e89ca137 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -25,11 +25,6 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" -def HLO_Dialect : Dialect { - let name = "mhlo"; - let cppNamespace = "::mlir::mhlo"; -} - class HLO_Op traits> : Op { // Whether this operation has a custom conversion to HLO or not. @@ -136,8 +131,8 @@ class HLO_UnaryElementwiseOp traits, } LogicalResult reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { - return deriveShapeFromFirstOperand(&builder, getOperation(), - &reifiedReturnShapes); + return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); } bool inferInputOutputShapeEquality(int input, int output) { return true; @@ -153,7 +148,7 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultShape], TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value operand" + "Value operand" >]; } @@ -168,8 +163,7 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp< BASE_HLO_ConvertOp { let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value operand, " - "Type result_element_ty" + "Value operand, Type result_element_ty" >]; let hasFolder = 1; @@ -247,7 +241,9 @@ def HLO_RealOp: HLO_UnaryElementwiseOp<"real", } def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", - [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp { + let hasFolder = 1; +} def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, @@ -293,8 +289,8 @@ class HLO_BinaryElementwiseOp traits> : } LogicalResult reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { - return deriveShapeFromFirstOperand(&builder, getOperation(), - &reifiedReturnShapes); + return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); } bool inferInputsShapeEquality(int lhs, int rhs) { return true; @@ -458,7 +454,7 @@ def HLO_SendOp : HLO_Op<"send", []> { let arguments = (ins HLO_TensorOrTuple:$operand, HLO_Token:$token, - ChannelHandle:$channel_id, + ChannelHandle:$channel_id, DefaultValuedAttr:$is_host_transfer ); @@ -483,7 +479,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> { let arguments = (ins HLO_Token:$token, - ChannelHandle:$channel_id, + ChannelHandle:$channel_id, DefaultValuedAttr:$is_host_transfer ); @@ -587,7 +583,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce", let arguments = (ins HLO_Tensor:$operand, I64ElementsAttr:$replica_groups, - OptionalAttr>:$channel_id + OptionalAttr:$channel_id ); let regions = (region SizedRegion<1>:$computation); let results = (outs HLO_Tensor); @@ -959,15 +955,6 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp { let results = (outs HLO_Tensor); } -def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ - StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, - StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, - StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, - StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> - ]> { - let description = "Structure of dimension information for dot product"; -} - def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp { let arguments = (ins HLO_Tensor:$lhs, @@ -1029,14 +1016,6 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp { let results = (outs HLO_Tensor); } -def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, - [StructFieldAttr<"offset_dims", I64ElementsAttr>, - StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>, - StructFieldAttr<"start_index_map", I64ElementsAttr>, - StructFieldAttr<"index_vector_dim", I64Attr>]> { - let description = "Structure of dimension information for gather"; -} - def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { let arguments = (ins HLO_Tensor:$operand, @@ -1114,7 +1093,7 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, HLO_Tensor:$operand, HLO_Tensor:$scatter_indices, HLO_Tensor:$updates, - ScatterDimensionNumbers:$scatter_dimension_numbers, + ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedAttr:$indices_are_sorted, DefaultValuedAttr:$unique_indices ); @@ -1124,6 +1103,8 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, let results = (outs HLO_Tensor); let hasCustomHLOConverter = 1; + + let hasFolder = 1; } // TODO(jpienaar): Add broadcastable trait. @@ -1220,6 +1201,8 @@ def HLO_PadOp: HLO_Op<"pad", // TODO(b/129422361): PadOp has a custom constructor for HLO. let hasCustomHLOConverter = 1; + + let hasFolder = 1; } def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index cba2dc370f0..da8c921a47b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -18,6 +18,13 @@ limitations under the License. include "mlir/IR/OpBase.td" +def HLO_Dialect : Dialect { + let name = "mhlo"; + let cppNamespace = "::mlir::mhlo"; +} + +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td" + def HLO_Pred : TypeAlias; // TODO(hinsu): Use signed integers instead of signless integer which is being @@ -614,15 +621,6 @@ class BASE_HLO_CaseOp { // XLA parallelism related op definitions. //===----------------------------------------------------------------------===// -// Represents a unique identifier for each Send/Recv instruction pair or -// optionally for collective instructions (AllReduce, CollectivePermute, -// AllToAll). Non-positive channel_id handle is equivalent to no channel id. -class ChannelHandle : StructAttr<"ChannelHandle", dialect, [ - StructFieldAttr<"handle", I64Attr>, - StructFieldAttr<"type", I64Attr>]> { - let description = "two 64-bit integers 'handle' and 'type'"; -} - class BASE_HLO_ReplicaIdOp { string summary = "ReplicaId operator"; @@ -712,6 +710,7 @@ def HLO_PrecisionConfigAttr: OptionalAttr< TypedArrayAttrBase>; + //===----------------------------------------------------------------------===// // Fast Fourier Transform Type enum definitions. //===----------------------------------------------------------------------===// @@ -1011,21 +1010,6 @@ class BASE_HLO_ConcatenateOp { // Common convolution attributes //===----------------------------------------------------------------------===// -class ConvDimensionNumbersBase - : StructAttr<"ConvDimensionNumbers", dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - class ConvolutionAttributes { dag attributes = (ins // Default value: one for each of the spatial dimension. @@ -1036,7 +1020,7 @@ class ConvolutionAttributes { OptionalAttr:$lhs_dilation, // Default value: one for each of the spatial dimension. OptionalAttr:$rhs_dilation, - ConvDimensionNumbersBase:$dimension_numbers, + ConvDimensionNumbers:$dimension_numbers, I64Attr:$feature_group_count, I64Attr:$batch_group_count, HLO_PrecisionConfigAttr:$precision_config @@ -1164,15 +1148,6 @@ class BASE_HLO_ReshapeOp { }]; } -class ScatterDimensionNumbers : StructAttr< - "ScatterDimensionNumbers", dialect, [ - StructFieldAttr<"update_window_dims", I64ElementsAttr>, - StructFieldAttr<"inserted_window_dims", I64ElementsAttr>, - StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>, - StructFieldAttr<"index_vector_dim", I64Attr>]> { - let description = "Structure of dimension information for scatter"; -} - class BASE_HLO_ScatterOp { string summary = "Scatter operator"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h new file mode 100644 index 00000000000..3b78ff8a367 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines structures used in MHLO and LMHLO. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td new file mode 100644 index 00000000000..d25eb5104c6 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td @@ -0,0 +1,73 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef HLO_OPS_BASE_STRUCTS +#define HLO_OPS_BASE_STRUCTS + +//===----------------------------------------------------------------------===// +// Dot dimensions enum definitions. +//===----------------------------------------------------------------------===// + +def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> + ]> { + let description = "Structure of dimension information for dot product"; +} + +def ScatterDimensionNumbers : StructAttr< + "ScatterDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"update_window_dims", I64ElementsAttr>, + StructFieldAttr<"inserted_window_dims", I64ElementsAttr>, + StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>, + StructFieldAttr<"index_vector_dim", I64Attr>]> { + let description = "Structure of dimension information for scatter"; +} + +def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"input_batch_dimension",I64Attr>, + StructFieldAttr<"input_feature_dimension", I64Attr>, + StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"output_batch_dimension", I64Attr>, + StructFieldAttr<"output_feature_dimension", I64Attr>, + StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { + + let description = "Structure of dimension information for conv op"; +} + +def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, + [StructFieldAttr<"offset_dims", I64ElementsAttr>, + StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>, + StructFieldAttr<"start_index_map", I64ElementsAttr>, + StructFieldAttr<"index_vector_dim", I64Attr>]> { + let description = "Structure of dimension information for gather"; +} + + +// Represents a unique identifier for each Send/Recv instruction pair or +// optionally for collective instructions (AllReduce, CollectivePermute, +// AllToAll). Non-positive channel_id handle is equivalent to no channel id. +def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [ + StructFieldAttr<"handle", I64Attr>, + StructFieldAttr<"type", I64Attr>]> { + let description = "two 64-bit integers 'handle' and 'type'"; +} + +#endif // HLO_OPS_BASE_STRUCTS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h new file mode 100644 index 00000000000..effa9ecc83b --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the LHLO dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +namespace mlir { +class OpBuilder; +} // namespace mlir + + +namespace mlir { +namespace lmhlo_gpu { + +class LmhloGpuDialect : public Dialect { + public: + explicit LmhloGpuDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "lmhlo_gpu"; } +}; + +} // namespace lmhlo_gpu +} // end namespace mlir + +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td new file mode 100644 index 00000000000..b3708bf4ff1 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td @@ -0,0 +1,210 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for LHMLO level GPU operations. +// Because these are LMHLO level operations, they operate on memrefs. + +#ifndef LHLO_GPU_OPS +#define LHLO_GPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td" + + +class LHLOGPU_Op traits = []> : + Op], traits)>; + +// Type for scratch buffers used by GPU library calls (memref) +def UntypedBuffer : MemRefRankOf<[I8], [1]>; + +// Cholesky info output buffer type. +def I32Buffer : MemRefOf<[I32]>; + +//===----------------------------------------------------------------------===// +// LMHLO ops representing batch norm library functions. +//===----------------------------------------------------------------------===// + +// Note: these are semantically different from similar LHLO as the GPU library +// calls generate or consume standard deviation, whereas LHLO ops generate or +// consume variance (= std-dev ^ 2). + +def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">, + BASE_HLO_BatchNormGradOp { + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$mean, + Arg:$stddev, + Arg:$grad_output, + Arg:$grad_operand, // gradient of $operand. + Arg:$grad_scale, + Arg:$grad_offset, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">, + BASE_HLO_BatchNormInferenceOp { + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$mean, + Arg:$stddev, + Arg:$output, + F32Attr:$epsilon, + I64Attr:$feature_index); +} + +def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">, + BASE_HLO_BatchNormTrainingOp { + + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$output, + Arg:$batch_mean, + Arg:$batch_stddev, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +//===----------------------------------------------------------------------===// +// LMHLO ops representing convolution library functions. +//===----------------------------------------------------------------------===// + +def ActivationModeNone : StrEnumAttrCase<"None">; +def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">; +def ActivationModeTanh : StrEnumAttrCase<"Relu">; +def ActivationModeRelu : StrEnumAttrCase<"Relu">; +def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">; +def ActivationModeReluX : StrEnumAttrCase<"ReluX">; +def ActivationModeBandPass : StrEnumAttrCase<"BandPass">; + +def ActivationAttr : StrEnumAttr<"Activation", + "Activation applied with fused convolution", + [ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh, + ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, + ActivationModeBandPass]>; + +def GpuConvolutionAttributes { + dag attributes = !con( + ConvolutionAttributes.attributes, + (ins F64Attr:$result_scale), + (ins ConvolutionBackendConfigAttr:$backend_config)); +} + +def GpuFusedConvolutionAttributes { + dag attributes = !con( + ConvolutionAttributes.attributes, + (ins F64Attr:$result_scale, + ActivationAttr:$activation_mode, + F64Attr:$side_input_scale), + (ins ConvolutionBackendConfigAttr:$backend_config)); +} + +def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$filter, + Arg:$output, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> { + let arguments = !con( + (ins + Arg:$d_output, + Arg:$filter, + Arg:$d_input, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$d_output, + Arg:$d_filter, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +// output = activation(result_scale * conv(input, filter) + +// side_input * side_input_scale + +// bias) +def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$filter, + Arg:$bias, + Arg:$side_input, + Arg:$output, + Arg:$scratch), + GpuFusedConvolutionAttributes.attributes); +} + +//===----------------------------------------------------------------------===// +// LMHLO ops representing other library functions. +//===----------------------------------------------------------------------===// + +// output = alpha * (lhs * rhs) +// Verify: beta = 0.0 +def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$output, + DotDimensionNumbers:$dot_dimension_numbers, + F64Attr:$alpha, + I64Attr:$batch_size, + I64Attr:$algorithm); +} + +// output = alpha(lhs * rhs) + beta * bias +def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$bias, + Arg:$output, + DotDimensionNumbers:$dot_dimension_numbers, + F64Attr:$alpha, + F64Attr:$beta, + I64Attr:$batch_size, + I64Attr:$algorithm); +} + +def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { + let arguments = (ins + Arg:$input, + Arg:$output, + Arg:$scratch, + Arg:$info, + BoolAttr:$is_upper); +} + +#endif // LHLO_GPU_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td new file mode 100644 index 00000000000..820e4ce64b0 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// We define the dialect here so that both structs and ops can refer to it. + +#ifndef LHLO_GPU_OPS_BASE +#define LHLO_GPU_OPS_BASE + +include "mlir/IR/OpBase.td" + +def LHLO_GPU_Dialect : Dialect { + let name = "lmhlo_gpu"; + let cppNamespace = "::mlir::lmhlo_gpu"; +} + +#endif // LHLO_GPU_OPS_BASE diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h new file mode 100644 index 00000000000..ff642b82c22 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ==============================================================================*/ + +// This file defines structures used in the LMHLO_GPU dialect. + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc" + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td new file mode 100644 index 00000000000..2236fc38e29 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td @@ -0,0 +1,29 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_GPU_OPS_STRUCTS +#define LHLO_GPU_OPS_STRUCTS + +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" + +def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig", + LHLO_GPU_Dialect, [ + StructFieldAttr<"algorithm", I64Attr>, + StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> { + let description = "GPU Convolution backend configuration"; +} + +#endif // LHLO_GPU_OPS_STRUCTS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index cc24e17c001..9dc6d7aa0c0 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines the operations used in the LXLA dialect. +// This file defines the operations used in the LHLO dialect. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #include "llvm/ADT/StringRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" @@ -33,11 +34,6 @@ limitations under the License. namespace mlir { class OpBuilder; -} // namespace mlir - -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc" - -namespace mlir { namespace lmhlo { class LmhloDialect : public Dialect { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index c013939c544..28e51351c7e 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -592,6 +592,7 @@ def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp { let arguments = (ins Arg:$lhs, Arg:$rhs, + DotDimensionNumbers:$dot_dimension_numbers, HLO_PrecisionConfigAttr:$precision_config, Arg:$output ); @@ -601,11 +602,8 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp { let arguments = (ins Arg:$operand, Arg:$start_indices, - I64Attr:$index_vector_dim, - I64ElementsAttr:$offset_dims, + GatherDimensionNumbers:$dimension_numbers, I64ElementsAttr:$slice_sizes, - I64ElementsAttr:$collapsed_slice_dims, - I64ElementsAttr:$start_index_map, Arg:$output ); } @@ -623,7 +621,7 @@ def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp { Arg:$scatter_indices, Arg:$updates, Arg:$output, - ScatterDimensionNumbers:$scatter_dimension_numbers, + ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedAttr:$indices_are_sorted, DefaultValuedAttr:$unique_indices ); @@ -699,7 +697,7 @@ def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>, Arg:$output, I64ElementsAttr:$replica_groups, DefaultValuedAttr:$constrain_layout, - OptionalAttr>:$channel_id, + OptionalAttr:$channel_id, DefaultValuedAttr:$use_global_device_ids ); let regions = (region SizedRegion<1>:$computation); @@ -712,7 +710,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>, Arg:$operand, Arg:$output, I64ElementsAttr:$source_target_pairs, - OptionalAttr>:$channel_id + OptionalAttr:$channel_id ); } diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 9cf1b6ce57e..d59dfd43d1b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -22,6 +22,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/TypeUtilities.h" namespace mlir { namespace lmhlo { @@ -96,7 +97,7 @@ template struct MapLhloOpToStdScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { return b->template create(loc, result_types, args, mlir::None); @@ -120,7 +121,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); @@ -130,8 +131,11 @@ inline Value MapLhloOpToStdScalarOp(Location loc, Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); - auto zero_intval = + Value zero_intval = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); + } auto lhs_gt_zero = b->create>(loc, CmpIPredicate::sge, lhs, zero_intval); auto neg_val = b->create>(loc, zero_intval, lhs); @@ -196,7 +200,7 @@ inline Value MapCompareOpToStdScalarOp(Location loc, ArrayRef args, OpBuilder* b) { const auto& lhs = args[0]; const auto& rhs = args[1]; - Type element_type = lhs.getType(); + Type element_type = getElementTypeOrSelf(lhs.getType()); if (element_type.isSignlessInteger()) { Optional predicate = getCmpPredicate(comparison_direction); @@ -268,8 +272,8 @@ template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type sourceType = args.front().getType(); - Type targetType = result_types.front(); + Type sourceType = getElementTypeOrSelf(args.front().getType()); + Type targetType = getElementTypeOrSelf(result_types.front()); if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { return b->create(loc, result_types, args, mlir::None); @@ -390,7 +394,7 @@ struct CompareSelectOpToStdScalarOp result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { auto predicate = getCmpPredicate(comparison_direction); assert(predicate.hasValue() && "expected valid comparison direction"); @@ -439,7 +443,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); @@ -449,8 +453,11 @@ inline Value MapLhloOpToStdScalarOp(Location loc, Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); - auto zero_intval = + Value zero_intval = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); + } return b->create>(loc, zero_intval, lhs); } return nullptr; @@ -461,11 +468,14 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (auto integer_type = element_type.dyn_cast()) { // lmhlo.not(x) -> x ^ -1 - auto all_ones = + Value all_ones = b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones); + } return b->create<::mlir::XOrOp>(loc, all_ones, args[0]); } return nullptr; @@ -493,26 +503,35 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - Type element_type = args.front().getType(); + Type element_type = getElementTypeOrSelf(args.front().getType()); if (auto float_type = element_type.dyn_cast()) { bool ignored; APFloat one_apfloat(1.0f); one_apfloat.convert(float_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &ignored); Value one = b->create(loc, one_apfloat, float_type); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + one = b->create<::mlir::SplatOp>(loc, vec_type, one); + } return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); } else if (auto integer_type = element_type.dyn_cast()) { // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) Value zero = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); - Value cmp = - b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero); Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>( loc, integer_type.getWidth() - 1, integer_type.getWidth()); - Value ashr = - b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); Value one = b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth()); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); + bitwidth_minus_one = + b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one); + one = b->create<::mlir::SplatOp>(loc, vec_type, one); + } + Value cmp = + b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero); + Value ashr = + b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); } @@ -583,6 +602,27 @@ struct HloOpToStdScalarOp { return impl::MapCompareOpToStdScalarOp( op.getLoc(), comparison_direction, result_types, args, b); } + + // Implementation for LHLO ops except lmhlo::CompareOp. + template ::value && + std::is_same, + std::false_type>::value>> + static Value map(Location loc, ArrayRef result_types, + ArrayRef args, OpBuilder* b, unsigned i = 0) { + return impl::MapLhloOpToStdScalarOp(loc, result_types, args, b); + } + + // Implementation for lmhlo::CompareOp. + template ::value>> + static Value map(Location loc, StringRef comparison_direction, + ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return impl::MapCompareOpToStdScalarOp( + loc, comparison_direction, result_types, args, b); + } }; } // namespace lmhlo diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index a58d0c6304c..c568165c0fb 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -27,7 +27,6 @@ namespace mlir { class LLVMTypeConverter; class LowerToLLVMOptions; class OwningRewritePatternList; -class BufferAssignmentPlacer; // Populates a collection of rewrite patterns to realize element-wise operations // on ranked tensors where possible. @@ -56,9 +55,9 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx); // Collection of rewrite patterns for lowering of HLO to LHLO dialect. -void populateHLOToLHLOConversionPattern( - MLIRContext *context, BufferAssignmentTypeConverter *converter, - OwningRewritePatternList *patterns); +void populateHLOToLHLOConversionPattern(MLIRContext *context, + BufferizeTypeConverter *converter, + OwningRewritePatternList *patterns); // Collection of rewrite patterns for lowering of HLO to Linalg dialect. void populateHLOToLinalgConversionPattern(MLIRContext *context, diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt index d7bb5057b00..7c0c11b1edd 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt @@ -43,6 +43,7 @@ add_mlir_library(MhloInferFusibilityOpInterface add_mlir_dialect_library(MhloDialect hlo_ops.cc + hlo_ops_base_structs.cc DEPENDS MLIRhlo_opsIncGen @@ -66,6 +67,15 @@ add_mlir_dialect_library(LmhloDialect ) target_link_libraries(LmhloDialect PUBLIC MLIRIR) +add_mlir_dialect_library(LmhloGPUDialect + lhlo_gpu_ops.cc + lhlo_gpu_ops_structs.cc + + DEPENDS + MLIRlhlo_gpu_opsIncGen +) +target_link_libraries(LmhloGPUDialect PUBLIC MLIRIR) + add_mlir_dialect_library(MhloRegisterDialects init.cc @@ -73,10 +83,12 @@ DEPENDS MLIRchlo_opsIncGen MLIRhlo_opsIncGen MLIRlhlo_opsIncGen + MLIRlhlo_gpu_opsIncGen ) target_link_libraries(MhloRegisterDialects PUBLIC ChloDialect MhloDialect LmhloDialect + LmhloGPUDialect ) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index 0a5bb0e018a..c04e27d50ed 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -62,8 +63,6 @@ namespace mlir { #include "hlo_patterns.cc.inc" } // namespace mlir -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc" - namespace mlir { namespace mhlo { @@ -1054,6 +1053,9 @@ LogicalResult ConcatenateOp::inferReturnTypes( return success(); } + if (first_type.getRank() == 0) + return emitOptionalError(location, "rank-0 values cannot be concatenated"); + auto out_shape = llvm::to_vector<6>(first_type.getShape()); // Determine what the non-concatenate dimensions should be. @@ -1785,6 +1787,61 @@ static LogicalResult Verify(PadOp op) { return success(); } +OpFoldResult PadOp::fold(ArrayRef operands) { + // If all padding is zero then it is an identity pad. + auto is_zero = [](const APInt& i) { return i == 0; }; + if (llvm::all_of(edge_padding_low().getIntValues(), is_zero) && + llvm::all_of(edge_padding_high().getIntValues(), is_zero) && + llvm::all_of(interior_padding().getIntValues(), is_zero)) + return operand(); + + // If any padding is negative then it isn't supported by the folder (yet). + auto is_negative = [](const APInt& i) { return i.slt(0); }; + if (llvm::all_of(edge_padding_low().getIntValues(), is_negative) && + llvm::all_of(edge_padding_high().getIntValues(), is_negative) && + llvm::all_of(interior_padding().getIntValues(), is_negative)) + return {}; + + DenseElementsAttr input = operands[0].dyn_cast_or_null(); + DenseElementsAttr padding = operands[1].dyn_cast_or_null(); + RankedTensorType return_type = getType().dyn_cast_or_null(); + if (!input || !input.getType().hasRank() || !padding || !return_type || + !return_type.hasStaticShape()) + return {}; + + // Fill the full result tensor with the padding value. + llvm::SmallVector result(return_type.getNumElements(), + padding.getValue({})); + + auto next_index = [](llvm::SmallVector& index, + llvm::ArrayRef shape) { + for (int64_t i = index.size() - 1; i >= 0; --i) { + ++index[i]; + if (index[i] < shape[i]) return true; + index[i] = 0; + } + return false; + }; + + // Iterate over all elements of the input tensor and copy it to the correct + // location in the output tensor. + llvm::SmallVector index(input.getType().getRank(), 0); + do { + uint64_t linear_index = 0; + uint64_t linear_index_multiplyer = 1; + for (int64_t i = index.size() - 1; i >= 0; --i) { + linear_index += + (edge_padding_low().getValue({uint64_t(i)}) + + index[i] * + (interior_padding().getValue({uint64_t(i)}) + 1)) * + linear_index_multiplyer; + linear_index_multiplyer *= return_type.getShape()[i]; + } + result[linear_index] = input.getValue(index); + } while (next_index(index, input.getType().getShape())); + return DenseElementsAttr::get(return_type, result); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// @@ -1931,6 +1988,14 @@ static Attribute UnaryFolder(Op* op, ArrayRef attrs) { return DenseElementsAttr::get(type, values); } +struct round { + APFloat operator()(const APFloat& f) { + APFloat r = f; + r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway); + return r; + } +}; + #define UNARY_FOLDER(Op, Func) \ OpFoldResult Op::fold(ArrayRef attrs) { \ if (getElementTypeOrSelf(getType()).isa()) \ @@ -1940,7 +2005,15 @@ static Attribute UnaryFolder(Op* op, ArrayRef attrs) { return {}; \ } +#define UNARY_FOLDER_FLOAT(Op, Func) \ + OpFoldResult Op::fold(ArrayRef attrs) { \ + if (getElementTypeOrSelf(getType()).isa()) \ + return UnaryFolder(this, attrs); \ + return {}; \ + } + UNARY_FOLDER(NegOp, std::negate); +UNARY_FOLDER_FLOAT(RoundOp, round); //===----------------------------------------------------------------------===// // BinaryOps @@ -2645,6 +2718,145 @@ OpFoldResult CompareOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// ScatterOp +//===----------------------------------------------------------------------===// + +llvm::SmallVector evaluateMhloRegion(Region& region, + ArrayRef inputs) { + if (region.getNumArguments() != inputs.size()) return {}; + + llvm::DenseMap values; + values.reserve(region.getNumArguments()); + for (auto it : llvm::zip(region.getArguments(), inputs)) { + values.try_emplace(std::get<0>(it), std::get<1>(it)); + } + + for (auto& op : region.getOps()) { + llvm::SmallVector inputs; + for (auto& operand : op.getOpOperands()) { + inputs.push_back(values.lookup(operand.get())); + } + if (isa(op)) return inputs; + + llvm::SmallVector results; + if (failed(op.fold(inputs, results))) return {}; + for (auto it : llvm::zip(op.getResults(), results)) { + if (!std::get<1>(it).is()) return {}; + values.insert({std::get<0>(it), std::get<1>(it).get()}); + } + } + return {}; +} + +OpFoldResult ScatterOp::fold(ArrayRef operands) { + auto base = operands[0].dyn_cast_or_null(); + auto index = operands[1].dyn_cast_or_null(); + auto update = operands[2].dyn_cast_or_null(); + if (!base || !index || !update) return {}; + + auto base_type = base.getType().dyn_cast(); + auto index_type = index.getType().dyn_cast(); + auto update_type = update.getType().dyn_cast(); + if (!base_type || !index_type || !update_type) return {}; + + // Add the virtual trailing dimension of size 1 if index_vector_dim equals to + // index_type.rank. + const int64_t index_vector_dim = + scatter_dimension_numbers().index_vector_dim().getInt(); + if (index_vector_dim == index_type.getRank()) { + auto index_shape = index_type.getShape().vec(); + index_shape.push_back(1); + index_type = + RankedTensorType::get(index_shape, index_type.getElementType()); + index = index.reshape(index_type).cast(); + } + + // Increment the multi-dimensional index vector based on the limits for each + // dimension specified by shape and returns false if the index rolled around + // with true otherwise. + auto next_index = [](llvm::SmallVector& index, + llvm::ArrayRef shape) { + for (int64_t i = index.size() - 1; i >= 0; --i) { + ++index[i]; + if (index[i] < shape[i]) return true; + index[i] = 0; + } + return false; + }; + + // Iterate over all elements of the update tensor, then find the corresponding + // value in the indices tensor to determine which location we have to update + // in the base/result tensor. + llvm::SmallVector results(base.getValues()); + llvm::SmallVector update_index(update_type.getRank(), 0); + llvm::SmallVector index_index; + index_index.reserve(index_type.getRank()); + llvm::SmallVector base_index; + base_index.reserve(base_type.getRank()); + do { + // Compute the index for the slice of the indices tensor for this update + // value. + index_index.clear(); + if (index_vector_dim == 0) index_index.push_back(0); + for (int64_t i = 0; i < update_index.size(); ++i) { + if (llvm::count(scatter_dimension_numbers().update_window_dims(), i) == 0) + index_index.push_back(update_index[i]); + if (index_index.size() == index_vector_dim) index_index.push_back(0); + } + + // Compute the index for the given update value in the base tensor. + base_index.assign(base_type.getRank(), 0); + uint64_t index_count = index_type.getShape()[index_vector_dim]; + for (uint64_t i = 0; i < index_count; ++i) { + uint64_t operand_dim = scatter_dimension_numbers() + .scatter_dims_to_operand_dims() + .getValue({i}) + .getSExtValue(); + index_index[index_vector_dim] = i; + base_index[operand_dim] += + index.getValue(index_index).getSExtValue(); + } + uint64_t update_window_dim_index = 0; + for (uint64_t i = 0; i < base_index.size(); ++i) { + if (llvm::count(scatter_dimension_numbers().inserted_window_dims(), i)) + continue; + base_index[i] += + update_index[scatter_dimension_numbers() + .update_window_dims() + .getValue({update_window_dim_index}) + .getSExtValue()]; + update_window_dim_index++; + } + + // Compute the linear index for the index into the base tensor. + int64_t linear_base_index = 0; + int64_t linear_base_index_multiplyer = 1; + for (int64_t i = base_index.size() - 1; i >= 0; --i) { + // Out of bound index have backend specific behaviour so avoid folding it. + if (base_index[i] < 0 || base_index[i] >= base_type.getShape()[i]) + return {}; + linear_base_index += base_index[i] * linear_base_index_multiplyer; + linear_base_index_multiplyer *= base_type.getShape()[i]; + } + + // Evaluate update computation and update the value with the newly computed + // attribute in the base tensor. + auto lhs = DenseElementsAttr::get( + RankedTensorType::get({}, base_type.getElementType()), + results[linear_base_index]); + auto rhs = DenseElementsAttr::get( + RankedTensorType::get({}, base_type.getElementType()), + update.getValue(update_index)); + auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs}); + if (new_value.size() != 1 || !new_value[0]) return {}; + results[linear_base_index] = + new_value[0].cast().getValue({}); + } while (next_index(update_index, update_type.getShape())); + + return DenseElementsAttr::get(base_type, results); +} + } // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc new file mode 100644 index 00000000000..90da1251ea0 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc @@ -0,0 +1,18 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" + +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td index b8b6cb80fba..bdb3e3cf490 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td @@ -18,15 +18,13 @@ limitations under the License. include "mlir/Dialect/Shape/IR/ShapeOps.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" -def EqualBinaryOperands : Constraint>; - // Canonicalization patterns. def DynamicBroadcastToOwnShape_1 : Pat< - (HLO_DynamicBroadcastInDimOp:$op $arg0, - (Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr), - (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>; + (HLO_DynamicBroadcastInDimOp:$op $x, + (Shape_ToExtentTensorOp (Shape_ShapeOfOp $x)), $attr), + (replaceWithValue $x)>; def DynamicBroadcastToOwnShape_2 : Pat< - (HLO_DynamicBroadcastInDimOp:$op $arg0, (Shape_ShapeOfOp $arg1), $attr), - (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>; + (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr), + (replaceWithValue $x)>; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc index 503b100c7ab..ca8c6a8d150 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc @@ -15,13 +15,15 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/register.h" void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry ®istry) { // clang-format off registry.insert(); + mlir::lmhlo_gpu::LmhloGpuDialect>(); // clang-format on } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc new file mode 100644 index 00000000000..10c5c0c2f9d --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the LMHLO GPU dialect. + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +namespace mlir { +namespace lmhlo_gpu { + +LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc" + >(); +} + +// TODO(jurahul): Add verification for operand shapes and ranks. + +} // namespace lmhlo_gpu +} // namespace mlir + +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc new file mode 100644 index 00000000000..cd2cfc58836 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc @@ -0,0 +1,18 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc index cba0d3b4788..4524cf3ec1f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -29,7 +29,6 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc new file mode 100644 index 00000000000..83dd4e62b47 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc @@ -0,0 +1,17 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index a2a940e2c58..7a1b3f4f416 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -42,7 +42,7 @@ namespace mhlo { namespace { template -using BaseOpConversion = BufferAssignmentOpConversionPattern; +using BaseOpConversion = OpConversionPattern; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, @@ -126,6 +126,60 @@ class HloToLhloOpConverter : public BaseOpConversion { } }; +// This specialization exists so that LMHLO's Dot can be given a specific set of +// dimension numbers, when lowering from MHLO's Dot, which does not have +// dimension numbers (it uses DotGeneral for this generalized notion of dot +// products). When these two dialects are in sync with respect to the +// Dot/DotGeneral issue, this specialization should be deleted. +template <> +class HloToLhloOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + LogicalResult matchAndRewrite( + mhlo::DotOp hloOp, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Operation* op = hloOp.getOperation(); + const auto& original_results = op->getResults(); + SmallVector buffer_args(operands.begin(), operands.end()); + for (auto result : llvm::enumerate(original_results)) { + RankedTensorType resultType = + result.value().getType().dyn_cast(); + if (!resultType) { + return failure(); + } + if (resultType.hasStaticShape()) { + buffer_args.push_back( + InsertAlloc(op->getLoc(), result.value(), &rewriter)); + } else { + SmallVector results_shape; + auto shape_type_op = dyn_cast(op); + if (!shape_type_op) return failure(); + if (failed( + shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) + return failure(); + buffer_args.push_back(InsertDynamicAllocAndDealloc( + op->getLoc(), result.value(), results_shape.front(), &rewriter)); + } + } + + // TODO(silvasean): Move this helper to MLIR core. + auto make_elements_attr = [&rewriter](ArrayRef integers) { + auto type = RankedTensorType::get({static_cast(integers.size())}, + rewriter.getIntegerType(64)); + return DenseIntElementsAttr::get(type, integers); + }; + auto dotOp = rewriter.create(op->getLoc(), llvm::None, + buffer_args, op->getAttrs()); + // MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O]. + auto dimension_numbers = mhlo::DotDimensionNumbers::get( + make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}), + make_elements_attr({0}), rewriter.getContext()); + dotOp.dot_dimension_numbersAttr(dimension_numbers); + rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); + return success(); + } +}; + struct HloToLhloDynamicBroadcastInDimOpConverter : public BaseOpConversion { public: @@ -236,6 +290,43 @@ struct HloToLhloDynamicReshapeConverter } }; +struct HloToLhloDotGeneralOpConverter + : public BaseOpConversion { + using BaseOpConversion::BaseOpConversion; + LogicalResult matchAndRewrite( + mhlo::DotGeneralOp dotGeneralOp, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Operation* op = dotGeneralOp.getOperation(); + + if (op->getResults().empty()) return failure(); + OpResult result = op->getResults()[0]; + RankedTensorType resultType = result.getType().dyn_cast(); + if (!resultType) return failure(); + + // The third buffer argument will be filled with what used to be the return + // type of the DotGeneral. + if (operands.size() != 2) return failure(); + std::array bufferArgs = {operands[0], operands[1], {}}; + + if (resultType.hasStaticShape()) { + bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter); + } else { + SmallVector results_shape; + auto shape_type_op = dyn_cast(op); + if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) + return failure(); + + bufferArgs[2] = InsertDynamicAllocAndDealloc( + op->getLoc(), result, results_shape.front(), &rewriter); + } + + rewriter.create(op->getLoc(), llvm::None, bufferArgs, + op->getAttrs()); + rewriter.replaceOp(op, bufferArgs[2]); + return success(); + } +}; + struct HloToLhloReduceOpConverter : public BaseOpConversion { public: using BaseOpConversion::BaseOpConversion; @@ -433,7 +524,7 @@ struct HloLegalizeToLhlo target.addLegalOp(); target.addIllegalDialect(); - BufferAssignmentTypeConverter converter; + BufferizeTypeConverter converter; auto isMemRefType = [](Type type) { return type.isa(); }; target.addDynamicallyLegalOp([&](FuncOp op) { auto inputs = op.getType().getInputs(); @@ -456,16 +547,16 @@ struct HloLegalizeToLhlo }); auto kind = results_escape_function - ? BufferAssignmentTypeConverter::KeepAsFunctionResult - : BufferAssignmentTypeConverter::AppendToArgumentsList; + ? BufferizeTypeConverter::KeepAsFunctionResult + : BufferizeTypeConverter::AppendToArgumentsList; converter.setResultConversionKind( kind); converter.setResultConversionKind(kind); populateHLOToLHLOConversionPattern(&context, &converter, &patterns); - populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, converter, - patterns); + populateWithBufferizeOpConversionPatterns( + &context, converter, patterns); populateShapeTypeConversionPatterns(&context, converter, patterns); if (failed(applyPartialConversion(getOperation(), target, patterns))) signalPassFailure(); @@ -480,11 +571,12 @@ struct HloLegalizeToLhlo }; } // namespace -void populateHLOToLHLOConversionPattern( - MLIRContext* context, BufferAssignmentTypeConverter* converter, - OwningRewritePatternList* patterns) { +void populateHLOToLHLOConversionPattern(MLIRContext* context, + BufferizeTypeConverter* converter, + OwningRewritePatternList* patterns) { // clang-format off patterns->insert< + HloToLhloDotGeneralOpConverter, HloToLhloDynamicBroadcastInDimOpConverter, HloToLhloDynamicReshapeConverter, HloToLhloOpConverter, @@ -531,7 +623,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloReturnOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter - >(context, *converter); + >(context); // clang-format on } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 57859b64bed..b64d66200cf 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -192,7 +192,7 @@ struct ConvToLinalgConverter : public OpConversionPattern { lmhlo::ConvOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { // Check validity of dimension information. - if (const lmhlo::ConvDimensionNumbers& dimensionNumbers = + if (const mhlo::ConvDimensionNumbers& dimensionNumbers = op.dimension_numbers()) { const int inputSpatialRank = llvm::size(dimensionNumbers.input_spatial_dimensions()); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index 6dc5b64a105..8f50ad0667f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/FoldUtils.h" @@ -73,6 +74,24 @@ class LhloFuseLinalgPass result_buffers.insert(operand); } } + // Resolve aliasing operations (like casts) on the result to identify + // results. This only handles escaping results. + // TODO(herhut): Use BufferizeAliasAnalysis for this. + llvm::SmallVector worklist(result_buffers.begin(), + result_buffers.end()); + while (!worklist.empty()) { + Value result = worklist.pop_back_val(); + auto definingOp = result.getDefiningOp(); + if (!definingOp) { + continue; + } + if (auto viewLike = dyn_cast(definingOp)) { + auto alias = viewLike.getViewSource(); + if (result_buffers.insert(alias).second) { + worklist.push_back(alias); + } + } + } MLIRContext* ctx = func.getContext(); OpBuilder b(func); OperationFolder folder(ctx); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index 2771afc6302..2041d22c62b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -59,6 +59,20 @@ struct DotOpConverter : public OpRewritePattern { return failure(); } + // We don't currently support batching dimensions, or multiple contraction + // dimensions. + mhlo::DotDimensionNumbers dot_dimension_numbers = + op.dot_dimension_numbers(); + if (dot_dimension_numbers.lhs_batching_dimensions().size() > 0 || + dot_dimension_numbers.rhs_batching_dimensions().size() > 0) + return failure(); + if (dot_dimension_numbers.lhs_contracting_dimensions().size() != 1 || + *dot_dimension_numbers.lhs_contracting_dimensions().begin() != 1 || + dot_dimension_numbers.rhs_contracting_dimensions().size() != 1 || + *dot_dimension_numbers.rhs_contracting_dimensions().begin() != 0) { + return failure(); + } + LogicalResult map_status = success(); auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) { SmallVector lhs_indices{ivs[0], ivs[2]}, diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir index 974585b12c5..7624ba929ea 100644 --- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -81,6 +81,14 @@ func @remainder_fold_float() -> tensor<4xf32> { return %2 : tensor<4xf32> } +// CHECK-LABEL: round_fold +func @round_fold() -> tensor<4xf32> { + %0 = mhlo.constant dense<[-1.5, -0.1, 1.1, 2.5]> : tensor<4xf32> + %1 = "mhlo.round_nearest_afz"(%0) : (tensor<4xf32>) -> tensor<4xf32> + return %1 : tensor<4xf32> + // CHECK: mhlo.constant dense<[-2.000000e+00, -0.000000e+00, 1.000000e+00, 3.000000e+00]> +} + // CHECK-LABEL: max_scalar_fold func @max_scalar_fold() -> tensor<4xi64> { %0 = mhlo.constant dense<7> : tensor<4xi64> @@ -1167,3 +1175,291 @@ func @not_fold_sqrt_neg_constants() -> tensor<4xf32> { // CHECK: mhlo.sqrt return %1 : tensor<4xf32> } + +// CHECK-LABEL: @tensor_flow_scatter_v1_update +func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, 2]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [10, 20, 30], [4, 5, 6], [70, 80, 90] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_v2_update +func @tensor_flow_scatter_v2_update() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, 2]> : tensor<2xi32> + %2 = constant dense<[[10, 30], [40, 60], [70, 90]]> : tensor<3x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<1> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>, + update_window_dims = dense<[0]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<3x2xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [10, 2, 30], [40, 5, 60], [70, 8, 90] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_add +func @tensor_flow_scatter_add() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, 2]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [11, 22, 33], [4, 5, 6], [77, 88, 99] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_repeated +func @tensor_flow_scatter_repeated() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[1, 1]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [84, 105, 126], [7, 8, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_multiple_batch +func @tensor_flow_scatter_multiple_batch() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[[0, 2], [2, 1]]> : tensor<2x2xi32> + %2 = constant dense<[[[10, 30], [40, 60], [70, 90]], [[5, 5], [5, 5], [5, 5]]]> : tensor<2x3x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 2 : i64, + inserted_window_dims = dense<1> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x3x2xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [11, 7, 38], [44, 10, 71], [77, 13, 104] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_nd +func @tensor_flow_scatter_nd() -> tensor<3x3x2xi32> { + %0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32> + %1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32> + %2 = constant dense<[[-10, 10], [-40, 40]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32> + return %3 : tensor<3x3x2xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [-10, 10], [-2, 2], [-3, 3] + // CHECK-SAME: [-40, 40], [-5, 5], [-6, 6] + // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9] + // CHECK-SAME: ]> : tensor<3x3x2xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_nd_index_vector +func @tensor_flow_scatter_nd_index_vector() -> tensor<3x3x2xi32> { + %0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32> + %1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32> + %2 = constant dense<[[-10, 10], [-20, 20]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 0 : i64, + inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32> + return %3 : tensor<3x3x2xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [-20, 20], [-10, 10], [-3, 3] + // CHECK-SAME: [-4, 4], [-5, 5], [-6, 6] + // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9] + // CHECK-SAME: ]> : tensor<3x3x2xi32> +} + +// CHECK-LABEL: @scatter_batch_dus +func @scatter_batch_dus() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[[2, 1], [1, 1]]> : tensor<2x2xi32> + %2 = constant dense<[[[10]], [[20]]]> : tensor<2x1x1xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 0 : i64, + inserted_window_dims = dense<> : tensor<0xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<[1, 2]> : tensor<2xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x1x1xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 20, 6], [7, 10, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @scatter_no_update_window_dim +func @scatter_no_update_window_dim() -> tensor<3xi32> { + %0 = constant dense<[0, 1, 2]> : tensor<3xi32> + %1 = constant dense<[[[0], [1]], [[2], [1]]]> : tensor<2x2x1xi32> + %2 = constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 2 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<> : tensor<0xi64> + }, + unique_indices = false + } : (tensor<3xi32>, tensor<2x2x1xi32>, tensor<2x2xi32>) -> tensor<3xi32> + return %3 : tensor<3xi32> + // CHECK: mhlo.constant dense<[10, 61, 32]> : tensor<3xi32> +} + +// CHECK-LABEL: @scatter_negative_index +func @scatter_negative_index() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, -1]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> + // CHECK: "mhlo.scatter" +} + +// CHECK-LABEL: @scatter_out_of_bound +func @scatter_out_of_bound() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[1, 5]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> + // CHECK: "mhlo.scatter" +} + +// CHECK-LABEL: @pad_identity_fold +func @pad_identity_fold(%arg0: tensor<5x7xf32>) -> tensor<5x7xf32> { + %0 = constant dense<0.0> : tensor + %1 = "mhlo.pad"(%arg0, %0) { + edge_padding_low = dense<0> : tensor<2xi64>, + edge_padding_high = dense<0> : tensor<2xi64>, + interior_padding = dense<0> : tensor<2xi64> + } : (tensor<5x7xf32>, tensor) -> tensor<5x7xf32> + return %1 : tensor<5x7xf32> + // CHECK: return %arg0 : tensor<5x7xf32> +} + +// CHECK-LABEL: @pad_fold +func @pad_fold() -> tensor<4x5xi32> { + %0 = constant dense<[[2, 3], [4, 5]]> : tensor<2x2xi32> + %1 = constant dense<1> : tensor + %3 = "mhlo.pad"(%0, %1) { + edge_padding_low = dense<[1, 0]> : tensor<2xi64>, + edge_padding_high = dense<[1, 2]> : tensor<2xi64>, + interior_padding = dense<[0, 1]> : tensor<2xi64> + } : (tensor<2x2xi32>, tensor) -> tensor<4x5xi32> + return %3 : tensor<4x5xi32> + // CHECK: constant dense<[ + // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1] + // CHECK-SAME: ]> : tensor<4x5xi32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir index cc60217be65..64009437182 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s // CHECK-LABEL: func @func_op_unranked_arg_result func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index 3caa4f0bd3b..fb106066d17 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s -// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s // BOTH-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -287,6 +287,28 @@ func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @gather +func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5x7xf32>) { + %tensor_operand = tensor_load %operand : memref<13x7xf32> + %tensor_idxs = tensor_load %idxs : memref<5xi32> + %tensor_result = + "mhlo.gather"(%tensor_operand, %tensor_idxs) + { dimension_numbers = + { collapsed_slice_dims = dense<0> : tensor<1xi64> + , index_vector_dim = 1 : i64 + , offset_dims = dense<1> : tensor<1xi64> + , start_index_map = dense<0> : tensor<1xi64> } + , indices_are_sorted = false + , name = "gather.71" + , slice_sizes = dense<[1, 7]> : tensor<2xi64> } + : (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32> + // BOTH: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<5x7xf32> + return +} + +// ----- + // BOTH-LABEL: func @imag_dyn func @imag_dyn(%operand: memref>, %result: memref) { %tensor_operand = tensor_load %operand : memref> @@ -511,7 +533,13 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { // PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // BOTH-NEXT: %[[ALLOC:.*]] = alloc -// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) { +// dot_dimension_numbers = { +// lhs_batching_dimensions = dense<> : tensor<0xi64>, +// lhs_contracting_dimensions = dense<1> : tensor<1xi64>, +// rhs_batching_dimensions = dense<> : tensor<0xi64>, +// rhs_contracting_dimensions = dense<0> : tensor<1xi64>}} +// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () %dot = "mhlo.dot"(%arg0, %arg0) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> // PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]]) @@ -632,4 +660,4 @@ func @shape_assuming_memref(%arg0: tensor) -> tensor { shape.assuming_yield %7 : tensor } return %2 : tensor -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir index 9a218b3657f..e51bdfec6f7 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-fuse-linalg.mlir @@ -3,7 +3,8 @@ // RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP #map0 = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} +#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { %temp_result = alloc() : memref<6x6xf32> @@ -73,7 +74,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, } %1 = alloc() : memref<100x10xf32> linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>) outs(%1 : memref<100x10xf32>) { @@ -83,7 +86,8 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, } dealloc %0 : memref<100x10xf32> linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%1 : memref<100x10xf32>) outs(%arg2 : memref<100x10xf32>) { @@ -132,7 +136,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // ----- #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} +#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", + "parallel"]} func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>, %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) { %temp_result = alloc() : memref<6x6x6x6xf32> @@ -190,7 +196,8 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // ----- #map0 = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} +#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, %summand_2: memref<6x6xf32>) -> memref<6x6xf32> { %temp_result = alloc() : memref<6x6xf32> @@ -244,3 +251,51 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, // PLOOP: addf // PLOOP: linalg.generic // PLOOP: mulf + +// ----- + +func @view_result(%arg0: memref, %arg1: memref, %arg2: index) + -> memref<*xf32> { + %c1 = constant 1 : index + %c0 = constant 0 : index + %1 = alloc(%arg2) : memref + linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%arg0 : memref) outs(%1 : memref) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %13 = absf %arg3 : f32 + linalg.yield %13 : f32 + } + %2 = lmhlo.reshape_memref_cast %1(%arg1) + : (memref, memref) -> memref<*xf32> + return %2 : memref<*xf32> +} + +// CHECK-LABEL: func @view_result +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for +// CHECK: linalg.generic +// CHECK: absf +// CHECK: reshape_memref_cast + +// TILED-LABEL: func @view_result +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-NOT: linalg.generic +// TILED: scf.for {{.*}} step %[[C2]] +// TILED-NOT: scf.for +// TILED: linalg.generic +// TILED: absf +// TILED: reshape_memref_cast + + +// PLOOP-LABEL: func @view_result +// PLOOP-NOT: linalg.generic +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel +// PLOOP: linalg.generic +// PLOOP: absf +// PLOOP: reshape_memref_cast + diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir index 87818045993..d020f7a083b 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir @@ -158,7 +158,14 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> // CHECK: return - "lmhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) { + dot_dimension_numbers = { + lhs_batching_dimensions = dense<> : tensor<0xi64>, + rhs_batching_dimensions = dense<> : tensor<0xi64>, + lhs_contracting_dimensions = dense<1> : tensor<1xi64>, + rhs_contracting_dimensions = dense<0> : tensor<1xi64> + } + } : (memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> () return } @@ -175,7 +182,14 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> // CHECK: return - "lmhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) { + dot_dimension_numbers = { + lhs_batching_dimensions = dense<> : tensor<0xi64>, + rhs_batching_dimensions = dense<> : tensor<0xi64>, + lhs_contracting_dimensions = dense<1> : tensor<1xi64>, + rhs_contracting_dimensions = dense<0> : tensor<1xi64> + } + } : (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> () return } diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir index debb035328c..47151089ccb 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir @@ -621,10 +621,10 @@ func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) { // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]): // CHECK-NEXT: %[[C0:.*]] = constant 0 : i16 -// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16 // CHECK-NEXT: %[[C15:.*]] = constant 15 : i16 -// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16 // CHECK-NEXT: %[[C1:.*]] = constant 1 : i16 +// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16 +// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16 // CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16 // CHECK-NEXT: linalg.yield %[[RESULT]] : i16 diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir new file mode 100644 index 00000000000..9e5ce67f39a --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir @@ -0,0 +1,99 @@ +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s + +// CHECK-LABEL: func @batch_norm_grad_memrefs +func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>, + %grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>, + %grad_offset: memref<8xf32>) -> () { + "lmhlo_gpu.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, + memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () + return +} + +// CHECK-LABEL: func @batch_norm_inference_memrefs +func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () { + "lmhlo_gpu.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> () + return +} + +// CHECK-LABEL: func @batch_norm_training_memrefs +func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>, + %batch_var: memref<8xf32>) -> () { + "lmhlo_gpu.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () + return +} + +// CHECK-LABEL: func @conv_forward +func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) { + %scratch = alloc() : memref<32xi8> + // This defined a 2D convolution over a 8x8 single channel input using a 2x2 + // filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W) + "lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch) + { dimension_numbers = {input_batch_dimension = 0 : i64, + input_feature_dimension = 1 : i64, + input_spatial_dimensions = dense<[2,3]> : tensor<2xi64>, + kernel_input_feature_dimension = 0 : i64, + kernel_output_feature_dimension = 1 : i64, + kernel_spatial_dimensions = dense<[2,3]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 1 : i64, + output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>}, + window_strides = dense<[1, 1]> : tensor<2xi64>, + padding = dense<[0,0]> : tensor<2xi64>, + lhs_dilation = dense<[1,1]> : tensor<2xi64>, + rhs_dilation = dense<[1,1]> : tensor<2xi64>, + feature_group_count = 1, + batch_group_count = 1, + result_scale = 1.0, + backend_config = {algorithm=0, tensor_ops_enabled = true } + } + : (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @gemm +func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) { + "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, + alpha = 0.5, + batch_size = 1, + algorithm = 0} + : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> () + return +} + + +// CHECK-LABEL: func @gemm_bias +func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, + %bias: memref<5x5xf32>, %output:memref<5x5xf32>) { + "lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, + alpha = 0.5, + beta = 1.0, + batch_size = 1, + algorithm = 0} + : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> () + return +} + +// CHECK-LABEL: func @cholesky +func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) { + %scratch = alloc() : memref<32xi8> + %info = alloc() : memref<32xi32> + "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true } + : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> () + return +} diff --git a/tensorflow/compiler/mlir/hlo/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir index 4462d9c45c6..fb4ab62371f 100644 --- a/tensorflow/compiler/mlir/hlo/tests/ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir @@ -328,6 +328,14 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor< // ----- +func @concat_0D(%arg0: tensor, %arg1: tensor) -> tensor<2xi32> { + // expected-error@+1 {{rank-0 values cannot be concatenated}} + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor, tensor) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// ----- + // CHECK-LABEL: @concat_1D func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp index d0c0e3c51e1..ed96dd5ffd8 100644 --- a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp +++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp @@ -15,6 +15,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" #include "mlir/InitAllDialects.h" @@ -31,6 +32,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); return failed( mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 3f4f3997536..a98e83b7e1e 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -326,6 +326,21 @@ static Optional GetTflitePoolParams(Operation* inst, namespace { +// Helper struct that wraps inputs/outputs of a single SignatureDef. +struct SignatureDefData { + // Note, we are using maps here to make order deterministic + // for easily testing only. + + // Inputs defined in the signature def mapped to tensor names. + std::map inputs; + // Outputs defined in the signature def mapped to tensor names. + std::map outputs; + // Method name exported by the signature def. + std::string method_name; + // SignatureDef key. + std::string signature_def_key; +}; + // Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. class Translator { public: @@ -334,16 +349,19 @@ class Translator { // internal error. static Optional Translate( ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper); + bool emit_custom_ops, const std::unordered_set& tags, + OpOrArgNameMapper* op_or_arg_name_mapper); private: enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + const std::unordered_set& saved_model_tags, OpOrArgNameMapper* op_or_arg_name_mapper) : module_(module), name_mapper_(*op_or_arg_name_mapper), - builder_(kInitialBufferSize) { + builder_(kInitialBufferSize), + saved_model_tags_(saved_model_tags) { // The first buffer must be empty according to the schema definition. empty_buffer_ = tflite::CreateBuffer(builder_); buffers_.push_back(empty_buffer_); @@ -450,6 +468,17 @@ class Translator { Optional>> CreateMetadataVector(); + // Builds and returns list of tfl.SignatureDef sections in the model. + Optional>> + CreateSignatureDefs(const std::vector& signature_defs); + + // Returns list of offsets for the passed 'items' in TensorMap structure + // inside the flatbuffer. + // 'items' is a map from tensor name in signatureDef to tensor name in + // the model. + std::vector> GetList( + const std::map& items); + // Uses the tf.entry_function attribute (if set) to initialize the op to name // mapping. void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr); @@ -472,6 +501,8 @@ class Translator { BufferOffset empty_buffer_; std::vector> buffers_; + // Maps tensor name in the graph to the tensor index. + absl::flat_hash_map tensor_index_map_; // Maps op name to index of the corresponding OperatorCode in opcodes_ vector. absl::flat_hash_map opcode_index_map_; @@ -490,6 +521,9 @@ class Translator { // The failed ops during legalization. std::set failed_flex_ops_; std::set failed_custom_ops_; + + // Set of saved model tags, if any. + const std::unordered_set saved_model_tags_; }; std::string Translator::UniqueName(mlir::Value val) { @@ -1131,6 +1165,7 @@ Optional> Translator::BuildSubGraph( } tensor_index_map.insert({value, tensors.size()}); + tensor_index_map_[name] = tensors.size(); auto tensor_or = BuildTensor(value, name, buffers_.size()); if (!tensor_or) return false; tensors.push_back(*tensor_or); @@ -1286,6 +1321,149 @@ Translator::CreateMetadataVector() { return builder_.CreateVector(metadata); } +// Helper method that returns list of all strings in a StringAttr identified +// by 'attr_key' and values are separated by a comma. +llvm::SmallVector GetStringsFromAttrWithSeparator( + mlir::DictionaryAttr attr, const std::string& attr_key) { + llvm::SmallVector result; + if (auto str = attr.get(attr_key).dyn_cast_or_null()) { + str.getValue().split(result, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + } + return result; +} + +// Helper method that return list of string for all the StringAttr in the +// Attribute identified by 'attr_name'. +std::vector GetStringsFromDictionaryAttr( + const llvm::SmallVector& dict_attrs, + const std::string& attr_name) { + std::vector result; + for (const auto& arg_attr : dict_attrs) { + auto attrs = arg_attr.getAttrs(); + for (const auto attr : attrs) { + if (attr.first.str() == attr_name) { + auto array_attr = attr.second.dyn_cast_or_null(); + if (!array_attr || array_attr.empty()) continue; + auto string_attr = array_attr[0].dyn_cast_or_null(); + if (!string_attr) continue; + result.push_back(string_attr.getValue().str()); + } + } + } + return result; +} + +std::vector BuildSignaturedef( + FuncOp main_op, const std::string& saved_model_tag) { + static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path"; + static const char kEntryFunctionAttributes[] = "tf.entry_function"; + + // Fetch inputs and outputs from the signature. + llvm::SmallVector arg_attrs, res_attrs; + main_op.getAllArgAttrs(arg_attrs); + main_op.getAllResultAttrs(res_attrs); + std::vector sig_def_inputs = + GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath); + std::vector sig_def_outputs = + GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath); + + // If no defined saved model signature, then return empty list. + // This can happen when we are converting model not from SavedModel. + if (sig_def_inputs.empty() || sig_def_outputs.empty()) return {}; + + // Fetch function inputs and outputs tensor names. + auto dict_attr = + main_op.getAttrOfType(kEntryFunctionAttributes); + if (!dict_attr) return {}; + + // Get Input and output tensor names from attribute. + llvm::SmallVector input_names = + GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs"); + llvm::SmallVector output_names = + GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs"); + + // Verify input size match the number of arguments. + if (input_names.size() != main_op.getNumArguments()) { + main_op.emitWarning() << "invalid entry function specification"; + return {}; + } + // Verify output size match the number of arguments. + auto term = main_op.back().getTerminator(); + if (output_names.size() != term->getNumOperands()) { + main_op.emitWarning() << "output names (" << output_names.size() + << ") != terminator operands (" + << term->getNumOperands() << ")"; + return {}; + } + // Verify number of tensors for inputs and outputs matches size + // of the list in the signature def. + if (input_names.size() != sig_def_inputs.size() || + output_names.size() != sig_def_outputs.size()) { + main_op.emitWarning( + "Mismatch between signature def inputs/outputs and main function " + "arguments."); + return {}; + } + // Exported method name. + auto exported_name = + main_op.getAttrOfType("tf_saved_model.exported_names"); + if (exported_name.empty()) { + main_op.emitError("Empty exported names for main Function"); + return {}; + } + // Fill the SignatureDefData container. + // We create vector of size 1 as TFLite now supports only 1 signatureDef. + std::vector result(1); + for (int i = 0; i < input_names.size(); ++i) { + result[0].inputs[sig_def_inputs[i]] = input_names[i].str(); + } + for (int i = 0; i < output_names.size(); ++i) { + result[0].outputs[sig_def_outputs[i]] = output_names[i].str(); + } + if (auto name_attr = exported_name[0].dyn_cast_or_null()) + result[0].method_name = name_attr.getValue().str(); + result[0].signature_def_key = saved_model_tag; + return result; +} + +std::vector> Translator::GetList( + const std::map& items) { + std::vector> result; + for (const auto& item : items) { + auto name_buf = builder_.CreateString(item.first); + tflite::TensorMapBuilder tensor_map_builder(builder_); + tensor_map_builder.add_name(name_buf); + tensor_map_builder.add_tensor_index(tensor_index_map_[item.second]); + result.push_back(tensor_map_builder.Finish()); + } + return result; +} + +Optional>> +Translator::CreateSignatureDefs( + const std::vector& signature_defs) { + std::vector> signature_defs_buffer; + for (const auto& signature_def_data : signature_defs) { + auto inputs = GetList(signature_def_data.inputs); + auto outputs = GetList(signature_def_data.outputs); + auto inputs_buf = builder_.CreateVector(inputs); + auto outputs_buf = builder_.CreateVector(outputs); + auto method_name_buf = + builder_.CreateString(signature_def_data.method_name); + auto signature_def_key_buf = + builder_.CreateString(signature_def_data.signature_def_key); + tflite::SignatureDefBuilder sig_def_builder(builder_); + sig_def_builder.add_inputs(inputs_buf); + sig_def_builder.add_outputs(outputs_buf); + sig_def_builder.add_method_name(method_name_buf); + sig_def_builder.add_key(signature_def_key_buf); + signature_defs_buffer.push_back(sig_def_builder.Finish()); + } + + return builder_.CreateVector(signature_defs_buffer); +} + bool UpdateEntryFunction(ModuleOp module) { if (module.lookupSymbol("main") != nullptr) { // We already have an entry function. @@ -1312,11 +1490,12 @@ bool UpdateEntryFunction(ModuleOp module) { Optional Translator::Translate( ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { + bool emit_custom_ops, const std::unordered_set& tags, + OpOrArgNameMapper* op_or_arg_name_mapper) { if (!UpdateEntryFunction(module)) return llvm::None; if (!IsValidTFLiteMlirModule(module)) return llvm::None; Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); + emit_custom_ops, tags, op_or_arg_name_mapper); return translator.TranslateInternal(); } @@ -1392,10 +1571,17 @@ Optional Translator::TranslateInternal() { auto metadata = CreateMetadataVector(); if (!metadata) return llvm::None; - auto model = tflite::CreateModel( - builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), - builder_.CreateVector(subgraphs), description, - builder_.CreateVector(buffers_), metadata_buffer, *metadata); + // Build SignatureDef + // We only have 1 entry point 'main' function, so build only 1 signature def. + auto main_fn_signature_def = BuildSignaturedef( + main_fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin()); + auto signature_defs = CreateSignatureDefs(main_fn_signature_def); + + auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, + builder_.CreateVector(opcodes_), + builder_.CreateVector(subgraphs), + description, builder_.CreateVector(buffers_), + metadata_buffer, *metadata, *signature_defs); tflite::FinishModelBuffer(builder_, model); tflite::UpdateOpVersion(builder_.GetBufferPointer()); tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); @@ -1519,12 +1705,10 @@ bool tflite::MlirToFlatBufferTranslateFunction( ModuleOp module, std::string* serialized_flatbuffer, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { - auto maybe_translated = - Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); - if (!maybe_translated) return true; - *serialized_flatbuffer = std::move(*maybe_translated); - return false; + return MlirToFlatBufferTranslateFunction( + module, serialized_flatbuffer, emit_builtin_tflite_ops, + emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{}, + op_or_arg_name_mapper); } bool tflite::MlirToFlatBufferTranslateFunction( @@ -1534,5 +1718,30 @@ bool tflite::MlirToFlatBufferTranslateFunction( OpOrArgLocNameMapper op_or_arg_name_mapper; return MlirToFlatBufferTranslateFunction( module, serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper); + emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{}, + &op_or_arg_name_mapper); +} + +bool tflite::MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + const std::unordered_set& saved_model_tags) { + OpOrArgLocNameMapper op_or_arg_name_mapper; + return MlirToFlatBufferTranslateFunction( + module, serialized_flatbuffer, emit_builtin_tflite_ops, + emit_select_tf_ops, emit_custom_ops, saved_model_tags, + &op_or_arg_name_mapper); +} + +bool tflite::MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + const std::unordered_set& saved_model_tags, + OpOrArgNameMapper* op_or_arg_name_mapper) { + auto maybe_translated = Translator::Translate( + module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops, + saved_model_tags, op_or_arg_name_mapper); + if (!maybe_translated) return true; + *serialized_flatbuffer = std::move(*maybe_translated); + return false; } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.h b/tensorflow/compiler/mlir/lite/flatbuffer_export.h index 0fbf2f07dfb..0888d2a4a41 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ #include +#include #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" @@ -33,11 +34,24 @@ bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, bool emit_select_tf_ops, bool emit_custom_ops); +// Same as above but takes SavedModel tags of the model. +bool MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + const std::unordered_set& saved_model_tags); + // Same as the above but with a custom op name mapper. bool MlirToFlatBufferTranslateFunction( mlir::ModuleOp module, std::string* serialized_flatbuffer, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); + +// Same as above but takes SavedModel tags of the model. +bool MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + const std::unordered_set& saved_model_tags, + tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); } // namespace tflite #endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index f7ee323957d..21cbf518967 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -876,6 +876,30 @@ def TFL_CosOp: TFL_Op<"cos", [ let hasFolder = 1; } +def TFL_CumsumOp: TFL_Op<"cumsum", [ + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoQuantizableResult, + TFL_OperandHasRank<1, 0>]> { + let summary = "Cumsum operator"; + + let description = [{ + Compute the cumulative sum of the tensor x along axis. + }]; + + let arguments = ( + ins TFL_TensorOf<[F32, I32, I64]>:$input, + TFL_I32Tensor:$axis, + DefaultValuedAttr:$exclusive, + DefaultValuedAttr:$reverse + ); + + let results = (outs TFL_TensorOf<[F32, I32, I64]>:$output); + + let hasOptions = 1; +} + def TFL_DepthwiseConv2DOp : TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> { let arguments = ( diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index e786bedc86d..005c5123906 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -90,9 +90,10 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; pass_config.lower_tensor_list_ops = true; - return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), - pass_config, result, - /*session=*/llvm::None); + return internal::ConvertMLIRToTFLiteFlatBuffer( + toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{}, + result, + /*session=*/llvm::None); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index aabc54335d1..7bbd3209dfe 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -177,7 +177,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, // TODO(b/153507667): Pass the session object when importing logic is removed. auto status = internal::ConvertMLIRToTFLiteFlatBuffer( - toco_flags, std::move(module), pass_config, result, + toco_flags, std::move(module), pass_config, tags, result, /*session=*/llvm::None); return status; } diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index a4e58123e05..ae2454dcf1e 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -273,7 +273,8 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { Status ConvertMLIRToTFLiteFlatBuffer( const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, - const mlir::TFL::PassConfig& pass_config, string* result, + const mlir::TFL::PassConfig& pass_config, + const std::unordered_set& saved_model_tags, string* result, llvm::Optional session) { bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); bool emit_select_tf_ops = toco_flags.enable_select_tf_ops(); @@ -297,8 +298,8 @@ Status ConvertMLIRToTFLiteFlatBuffer( auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result, - &pm); + emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, + saved_model_tags, result, &pm); if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( // rename once we enable the new converter feature flag. diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index d79bdc6df67..d4f9e739121 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_ #include +#include #include #include "llvm/ADT/Optional.h" @@ -48,7 +49,8 @@ Status PopulateQuantizationSpecs( // This will also run relevant passes as well. Status ConvertMLIRToTFLiteFlatBuffer( const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, - const mlir::TFL::PassConfig& pass_config, string* result, + const mlir::TFL::PassConfig& pass_config, + const std::unordered_set& saved_model_tags, string* result, llvm::Optional session); // Give a warning for any unused flags that have been specified. diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt index facd6005e7d..9f8d82eb184 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/add.pbtxt @@ -96,5 +96,6 @@ versions { # CHECK-NEXT: metadata: [ { # CHECK-NEXT: name: "min_runtime_version", # CHECK-NEXT: buffer: 4 -# CHECK-NEXT: } ] +# CHECK-NEXT: } ], +# CHECK-NEXT: signature_defs: [ ] # CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD index 41fbbbcb9c5..e579aea558e 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD @@ -54,6 +54,7 @@ tf_native_cc_binary( deps = [ "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", ], @@ -70,6 +71,7 @@ tf_native_cc_binary( deps = [ "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", ], diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_legacy_reshape.cc b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_legacy_reshape.cc index 3d4f440efe6..f5b73207157 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_legacy_reshape.cc +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_legacy_reshape.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" using llvm::Optional; using llvm::cl::opt; @@ -95,7 +96,8 @@ Optional> RemoveConstantOpInReshape( // Find the reshape ops and make it single operand. for (auto& sub_graph : model->subgraphs) { for (auto& op : sub_graph->operators) { - if (model->operator_codes[op->opcode_index]->builtin_code == + if (tflite::GetBuiltinCode( + model->operator_codes[op->opcode_index].get()) == tflite::BuiltinOperator_RESHAPE) { auto& output_tensor = sub_graph->tensors[op->outputs[0]]; auto shape = output_tensor->shape; diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc index 1349622eefc..c1f2417cdb5 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" using llvm::Optional; using llvm::cl::opt; @@ -114,7 +115,8 @@ Optional> InjectStatsToFullyConnected( // Find the tensors and inject the min and max to the input and output for (auto& sub_graph : model->subgraphs) { for (auto& op : sub_graph->operators) { - if (model->operator_codes[op->opcode_index]->builtin_code == + if (tflite::GetBuiltinCode( + model->operator_codes[op->opcode_index].get()) == tflite::BuiltinOperator_FULLY_CONNECTED) { // inject min/max to the input and output tensors auto& input_tensor = sub_graph->tensors[op->inputs[0]]; diff --git a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir index 138614d81e6..d56c2cc221a 100644 --- a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir +++ b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir @@ -3442,8 +3442,8 @@ func @sgnn_projection(%arg0: tensor {tf._user_specified_name = "va %0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64> %1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor) -> tensor %2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor, tensor<10x1xi64>) -> tensor<10x?xf64> - %3 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> - %4 = "tf.Reshape"(%2, %3) : (tensor<10x?xf64>, tensor<1xi64>) -> tensor + %3 = "tf.Const"() {value = dense<[-1, 10]> : tensor<2xi64>} : () -> tensor<2xi64> + %4 = "tf.Reshape"(%2, %3) : (tensor<10x?xf64>, tensor<2xi64>) -> tensor return %4 : tensor } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 4de278ee324..97a496a0b89 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1361,7 +1361,8 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, % // CHECK-LABEL: conv2d_backprop_input // CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> - // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> + // CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST]]) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CAST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> // CHECK: %[[CST_0:.*]] = constant unit // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> @@ -1587,10 +1588,31 @@ func @tranpose_int64_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { // CHECK: "tfl.transpose" } -func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32> { +func @tranpose_arg32(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32> { %0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> - // CHECK-LABEL: tranpose_arg + // CHECK-LABEL: tranpose_arg32 // CHECK: "tfl.transpose" } +func @tranpose_arg64(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi64>) -> tensor<3x2xf32> { + %0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> + // CHECK-LABEL: tranpose_arg64 + // CHECK: "tfl.transpose" +} + +func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor) -> tensor<3x3xf32> { + %0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + // CHECK-LABEL: cumsum + // CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> +} + +func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor) -> tensor<3x3xf32> { + %0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + // CHECK-LABEL: cumsum_invalid + // CHECK-NOT: "tfl.cumsum" +} + diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir index f5e5087d420..b2f684e6be8 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir @@ -116,6 +116,7 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 10 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} ^bb0(%arg0: tensor<1x384xf32>, %arg1: tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>): diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir index 02767ddceba..a067826f86d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir @@ -100,6 +100,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir index a576c84e207..ef82175a47d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir @@ -91,6 +91,7 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir index 7d2d84d242a..f4bc10b2fe2 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir @@ -93,6 +93,7 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir index 66ec5ed8a04..f7ff99b117d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir @@ -97,6 +97,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir index dc8590b4a20..9aca1ecb47d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir @@ -54,6 +54,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // IMPORT: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir index 245260d994d..b2d7f611ede 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir @@ -47,6 +47,7 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir index 82ac52d2a64..b8749b4b76c 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir @@ -60,6 +60,7 @@ func @main(tensor<4xcomplex>, tensor<4xcomplex>) -> tensor<4xcomplex>, tensor<4xcomplex>) -> tensor<4xcomplex> loc("add") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir index 543ca7fc3e7..c8f3949500e 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_f64.mlir @@ -60,6 +60,7 @@ func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 4 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir index e15c7f23585..059cfc0d54e 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir @@ -99,6 +99,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir index 58f693c1b0a..b01bafe4ea7 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir @@ -69,6 +69,7 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 5 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %cst = constant unit diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir index 913a69c4d46..95bcc1547f7 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir @@ -69,6 +69,7 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 5 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %cst = constant unit diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir index 86794afdf4c..c89239c2e6f 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir @@ -166,6 +166,7 @@ // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 11 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir index d24fa33fa13..f32fe880121 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir @@ -87,6 +87,7 @@ func @main(tensor<4xi1>) -> tensor<4xi1> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir index 19bfc661425..017870ca334 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir @@ -258,6 +258,7 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 26 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir index a48e3c82e9d..10332e45bec 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir @@ -320,5 +320,6 @@ func @main(%arg0: tensor<1x528x!quant.uniform> // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 23 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir index c4cb910f17b..eeca4267524 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir @@ -140,6 +140,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 8 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir index 49d71f24d2d..3fb00cf6024 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata.mlir @@ -33,4 +33,5 @@ module attributes { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir index 04ceb3855f2..c8af68a190d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir @@ -66,6 +66,7 @@ func @main(tensor<3x!quant.uniform>) -> tensor<3x!quant.uniform>, value = dense<2> : tensor<3xi8>} : () -> tensor<3x!quant.uniform> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v3.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v3.mlir index 0206724f8ad..441dbd8f925 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v3.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v3.mlir @@ -66,6 +66,7 @@ func @main(tensor<3x!quant.uniform>) -> tensor<3x!quant.uniform>, value = dense<2> : tensor<3xi8>} : () -> tensor<3x!quant.uniform> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/nn.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/nn.mlir index 455a7695d48..ec0fd07c25a 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/nn.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/nn.mlir @@ -55,6 +55,7 @@ func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] + // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> loc("avgpool") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir index e39ded18b86..60360c7ded6 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir @@ -48,6 +48,7 @@ // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 3 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} func @main(%arg0: tensor<4xf32>, %arg1: tensor<4x!quant.uniform>) -> tensor<4xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir index 81065798271..93581e54f10 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir @@ -165,6 +165,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 10 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %0 = "tfl.pseudo_const" () {value = dense<[1, 1001]> : tensor<2xi32>} : () -> tensor<2xi32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir index 129e037a2ee..af59475f6a1 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir @@ -59,6 +59,7 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> { // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 4 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } %0 = "tfl.pseudo_const" () {value = dense<[6]> : tensor<1xi32>} : () -> tensor<1xi32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/signature_def.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/signature_def.mlir new file mode 100644 index 00000000000..b9866b4696d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/signature_def.mlir @@ -0,0 +1,117 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s + +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: deprecated_builtin_code: 9, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: FULLY_CONNECTED +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 1, 384 ], +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "serving_default_input2:0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: shape_signature: [ -1, 384 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 384 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "serving_default_input1:0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: shape_signature: [ -1, 384 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 5 ], +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "std.constant", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 5, 384 ], +// CHECK-NEXT: buffer: 4, +// CHECK-NEXT: name: "std.constant1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 5, 384 ], +// CHECK-NEXT: buffer: 5, +// CHECK-NEXT: name: "std.constant2", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 5 ], +// CHECK-NEXT: buffer: 6, +// CHECK-NEXT: name: "StatefulPartitionedCall:0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: shape_signature: [ -1, 5 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 5 ], +// CHECK-NEXT: buffer: 7, +// CHECK-NEXT: name: "StatefulPartitionedCall:1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: shape_signature: [ -1, 5 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1 ], +// CHECK-NEXT: outputs: [ 6, 5 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 3, 2 ], +// CHECK-NEXT: outputs: [ 5 ], +// CHECK-NEXT: builtin_options_type: FullyConnectedOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: inputs: [ 0, 4, 2 ], +// CHECK-NEXT: outputs: [ 6 ], +// CHECK-NEXT: builtin_options_type: FullyConnectedOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", + +// CHECK: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 8 +// CHECK-NEXT: } ], +// CHECK-NEXT: signature_defs: [ { +// CHECK-NEXT: inputs: [ { +// CHECK-NEXT: name: "input1", +// CHECK-NEXT: tensor_index: 1 +// CHECK-NEXT: }, { +// CHECK-NEXT: name: "input2" +// CHECK-NEXT: } ], +// CHECK-NEXT: outputs: [ { +// CHECK-NEXT: name: "end_logits", +// CHECK-NEXT: tensor_index: 5 +// CHECK-NEXT: }, { +// CHECK-NEXT: name: "start_logits", +// CHECK-NEXT: tensor_index: 6 +// CHECK-NEXT: } ], +// CHECK-NEXT: method_name: "serving_default", +// CHECK-NEXT: key: "" +// CHECK-NEXT: } ] +// CHECK-NEXT:} +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 554 : i32}, tf_saved_model.semantics} { + func @main(%arg0: tensor {tf_saved_model.index_path = ["input2"]}, %arg1: tensor {tf_saved_model.index_path = ["input1"]}) -> (tensor {tf_saved_model.index_path = ["start_logits"]}, tensor {tf_saved_model.index_path = ["end_logits"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input2:0,serving_default_input1:0", outputs = "StatefulPartitionedCall:1,StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = constant dense<0.000000e+00> : tensor<5xf32> + %cst_0 = constant dense<1.0> : tensor<5x384xf32> + %cst_1 = constant dense<1.0> : tensor<5x384xf32> + %0 = "tfl.fully_connected"(%arg0, %cst_0, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor, tensor<5x384xf32>, tensor<5xf32>) -> tensor + %1 = "tfl.fully_connected"(%arg0, %cst_1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor, tensor<5x384xf32>, tensor<5xf32>) -> tensor + return %1, %0 : tensor, tensor + } +} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir index ea380f8f47d..fd0e0386c46 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir @@ -105,6 +105,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 6 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } %0 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir index e2ad4c73baa..3f48facd122 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir @@ -87,6 +87,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 7 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir index c1d30a9b4d4..1b6b66ed087 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir @@ -88,6 +88,7 @@ func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) -> // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 7 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir index 87e3ccf4688..68b21765717 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir @@ -199,6 +199,7 @@ // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 14 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir index 5252dc21a59..1256dd19036 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir @@ -70,6 +70,7 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 5 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT:} %cst = constant unit diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir index 38c1ed40b35..ffb5b2c781b 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir @@ -257,6 +257,7 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 26 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir index 575499c8d66..5b29c1ff050 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir @@ -87,6 +87,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 7 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir index e58c54219ab..51935676eed 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir @@ -199,6 +199,7 @@ // CHECK-NEXT: name: "min_runtime_version", // CHECK-NEXT: buffer: 14 // CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } func @main(%arg0: tensor, %arg1: tensor<1xf32>) -> tensor<1xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 154e4fa8e0e..bedf77f726a 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -26,6 +26,26 @@ func @fusedDepthwiseConv2dRelu6(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16 // CHECK: return %0 } +// CHECK-LABEL: fusedMaxPool2dRelu +func @fusedMaxPool2dRelu(%arg0: tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> { + %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> + %1 = "tfl.relu"(%0) : (tensor<1x73x73x16xf32>) -> tensor<1x73x73x16xf32> + return %1 : tensor<1x73x73x16xf32> + + // CHECK: %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "RELU", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> + // CHECK: return %0 +} + +// CHECK-LABEL: fusedAvgPool2dRelu1 +func @fusedAvgPool2dRelu1(%arg0: tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> { + %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> + %1 = "tfl.relu_n1_to_1"(%0) : (tensor<1x73x73x16xf32>) -> tensor<1x73x73x16xf32> + return %1 : tensor<1x73x73x16xf32> + + // CHECK: %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "RELU_N1_TO_1", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x16xf32>) -> tensor<1x73x73x16xf32> + // CHECK: return %0 +} + // CHECK-LABEL: fuseAddIntoConv2d func @fuseAddIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> { %cst = constant dense<1.5> : tensor<16xf32> @@ -971,6 +991,16 @@ func @Relu(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: return %[[RESULT]] } +// CHECK-LABEL: Relu_bf16 +func @Relu_bf16(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> { + %cst = constant dense<0.0> : tensor<2x3xbf16> + %0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xbf16>, tensor<2x3xbf16>) -> tensor<2x3xbf16> + return %0 : tensor<2x3xbf16> + + // CHECK: %[[RESULT:.*]] = "tfl.relu"(%arg0) + // CHECK: return %[[RESULT]] +} + // CHECK-LABEL: Relu1 func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { %cst = constant dense<-1.0> : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 4e7c08945d4..186c8631e56 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -62,40 +62,6 @@ func @Conv2dNCHW(%arg0: tensor<256x3x32x32xf32>, %arg1: tensor<3x3x3x16xf32>) -> // LAYOUT: "tfl.conv_2d" } - -func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) { -^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>): - // OK - %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // Unsupported training - %1:5 = "tf.FusedBatchNorm"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // Use other output - %2:5 = "tf.FusedBatchNorm"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - - return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32> - -// CHECK-LABEL: fusedBatchNorm -// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03> -// variance + epsilon -// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]]) -// rsqrt(variance + epsilon) -// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]]) -// scale * rsqrt(variance + epsilon) -// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]]) -// x * scale * rsqrt(variance + epsilon) -// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]]) -// mean * scale * rsqrt(variance + epsilon) -// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]]) -// offset - mean * scale * rsqrt(variance + epsilon) -// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]]) -// x * scale * rsqrt(variance + epsilon) + -// offset - mean * scale * rsqrt(variance + epsilon) -// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]]) - -// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) -// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) -} - func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) { ^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>): // OK diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 0835e3fdefa..aa3545d9beb 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -143,6 +143,7 @@ int main(int argc, char **argv) { mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context); StatusOr module; + std::unordered_set tags; tensorflow::GraphImportConfig specs; specs.upgrade_legacy = upgrade_legacy; @@ -161,8 +162,7 @@ int main(int argc, char **argv) { module = tensorflow::errors::InvalidArgument( "Importing saved model should not have input_mlir set"); - std::unordered_set tags = - absl::StrSplit(saved_model_tags, ','); + tags = absl::StrSplit(saved_model_tags, ','); std::vector exported_names_vector = absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); absl::Span exported_names(exported_names_vector); @@ -241,7 +241,7 @@ int main(int argc, char **argv) { std::string result; auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer( module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, quant_specs, &result, &pm); + emit_select_tf_ops, emit_custom_ops, quant_specs, tags, &result, &pm); if (!status.ok()) return kTrFailure; std::string error_msg; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index a2387e89483..622e32c2766 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -137,8 +137,9 @@ StatusOr LoadFromGraphdefOrMlirSource( Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result, - mlir::PassManager* pass_manager) { + const mlir::TFL::QuantizationSpecs& quant_specs, + const std::unordered_set& saved_model_tags, + std::string* result, mlir::PassManager* pass_manager) { // Explicitly disable dumping Op details on failures. module.getContext()->printOpOnDiagnostic(false); @@ -171,7 +172,7 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( if (!quant_specs.RunWeightQuantization()) { if (tflite::MlirToFlatBufferTranslateFunction( module, result, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops)) { + emit_custom_ops, saved_model_tags)) { return statusHandler.ConsumeStatus(); } } else { @@ -180,7 +181,7 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( std::string pre_quantized_result; if (tflite::MlirToFlatBufferTranslateFunction( module, &pre_quantized_result, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops)) { + emit_select_tf_ops, emit_custom_ops, saved_model_tags)) { return statusHandler.ConsumeStatus(); } flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240); diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index ec2f9e10d26..95b6097e1eb 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -63,8 +63,9 @@ stream_executor::port::StatusOr ImportSavedModel( Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result, - mlir::PassManager* pass_manager); + const mlir::TFL::QuantizationSpecs& quant_specs, + const std::unordered_set& saved_model_tags, + std::string* result, mlir::PassManager* pass_manager); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 322da815a47..85df2417ef2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -54,7 +54,7 @@ def ExtractSingleElementAsInt32 : NativeCodeCall< "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast()).getInt())">; // Converts tensor with int64 to int32. -def CreateCastToInt32 : NativeCodeCall< +def CreateTFLCastToInt32Op : NativeCodeCall< "CreateCastToInt32($0, $_loc, $_builder)">; // Checks whether the given operation has static shapes and same shapes of all inputs. @@ -216,12 +216,9 @@ def LegalizeSqueeze : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>; def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>; -def LegalizeTransposeInt64 : Pat< - (TF_TransposeOp $arg, (ConstantOp Int64ElementsAttr:$perm)), - (TFL_TransposeOp $arg, (CreateCastToInt32 $perm))>; - def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm), - (TFL_TransposeOp $arg, $perm)>; + (TFL_TransposeOp $arg, + (CreateTFLCastToInt32Op $perm))>; def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>; def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>; @@ -444,3 +441,7 @@ def LegalizeMatrixSetDiag : Pat< def LegalizeScatterNd : Pat< (TF_ScatterNdOp I32Tensor:$indices, $updates, $shape), (TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>; + +def LegalizeCumsum : Pat< + (TF_CumsumOp $input, $axis, $exclusive, $reverse), + (TFL_CumsumOp $input, $axis, $exclusive, $reverse)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 13c7a08a094..6685045d59f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -118,14 +118,11 @@ bool HasSameStaticShapes(Operation* op) { } // Util that casts 'val' to Int32 by adding a cast Op. -Value CreateCastToInt32(Attribute val, Location loc, - PatternRewriter& rewriter) { +Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { auto shape = val.getType().dyn_cast().getShape(); IntegerType new_ele_type = rewriter.getIntegerType(32); ShapedType new_type = RankedTensorType::get(shape, new_ele_type); - return rewriter.create(loc, new_type, - rewriter.create(loc, val), - rewriter.getBoolAttr(false)); + return rewriter.create(loc, new_type, val); } #include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc" diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 8243ed2a620..57925663d74 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -57,15 +57,31 @@ multiclass FuseActFnIntoConvOpPat { [(HasOneUse $conv_out)]>; } +multiclass FuseActFnIntoPoolOpPat { + def FuseActivationFuncWithAvgPool#ActFnOp#ActFnAttr : Pat< + (ActFnOp (TFL_AveragePool2DOp:$pool_out $input, $filter_height, + $filter_width, $padding, $stride_h, $stride_w, TFL_AF_None)), + (TFL_AveragePool2DOp $input, $filter_height, $filter_width, $padding, + $stride_h, $stride_w, ActFnAttr), + [(HasOneUse $pool_out)]>; + def FuseActivationFuncWithMaxPool#ActFnOp#ActFnAttr : Pat< + (ActFnOp (TFL_MaxPool2DOp:$pool_out $input, $padding, $stride_w, $stride_h, + $filter_width, $filter_height, TFL_AF_None)), + (TFL_MaxPool2DOp $input, $padding, $stride_w, $stride_h, + $filter_width, $filter_height, ActFnAttr), + [(HasOneUse $pool_out)]>; +} + // TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused // activation functions. // Currently we're not fusing tanh, sigmoid, hard_swish and other activations // those cannot be simply translated into clamping. foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], [TFL_Relu6Op, TFL_AF_Relu6], - [TFL_Relu1Op, TFL_AF_Relu1]] in + [TFL_Relu1Op, TFL_AF_Relu1]] in { defm : FuseActFnIntoConvOpPat; - + defm : FuseActFnIntoPoolOpPat; +} class CanFuseConvOrDepthwiseConv : Constraint< CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>; @@ -222,8 +238,6 @@ def eliminate_dq_q_pairs : Pat< [(NotFromQuantOpOrSameQuantType $in, $qt)]>; -// Constraint that makes sure both operands are the same operands. -def EqualOperands : Constraint>; // Checks if the operand has rank == n @@ -235,28 +249,26 @@ def MatchHardSwishPattern1 : Pat< (TFL_MulOp (TFL_MulOp $x, (TFL_AddOp - $y, + $x, (ConstantOp ConstantAttr, "3.0f">), TFL_AF_Relu6), TFL_AF_None), (ConstantOp ConstantAttr, "0.166666666f">), TFL_AF_None), - (TFL_HardSwishOp $x), - [(EqualOperands $x, $y)]>; + (TFL_HardSwishOp $x)>; def MatchHardSwishPattern2 : Pat< (TFL_MulOp $x, (TFL_MulOp (TFL_AddOp - $y, + $x, (ConstantOp ConstantAttr, "3.0f">), TFL_AF_Relu6), (ConstantOp ConstantAttr, "0.166666666f">), TFL_AF_None), TFL_AF_None), - (TFL_HardSwishOp $x), - [(EqualOperands $x, $y)]>; + (TFL_HardSwishOp $x)>; // Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to // incorrect placement in the quantization aware training. @@ -265,14 +277,13 @@ def MatchHardSwishQuantized : Pat< (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp (TFL_MulOp $x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp - $y, + $x, (ConstantOp ConstantAttr, "3.0f">), TFL_AF_Relu6), $qattr2)), TFL_AF_None), $qattr1)), (ConstantOp ConstantAttr, "0.166666666f">), TFL_AF_None), - (TFL_HardSwishOp $x), - [(EqualOperands $x, $y)]>; + (TFL_HardSwishOp $x)>; // Constraint that the attribute value is less than 'n' class ConstDoubleValueLessThan : Constraint< @@ -293,47 +304,44 @@ multiclass L2NormalizePatterns { // Mul->Rsqrt->Sum->Square Or // Div->sqrt->Sum->Square def L2NormalizePattern1#FirstOp#SecondOp : Pat< - (FirstOp $operand1, + (FirstOp $x, (SecondOp (TFL_SumOp - (TFL_SquareOp:$sq_op $square_operand), + (TFL_SquareOp:$sq_op $x), (ConstantOp I32ElementsAttr:$axis), $keep_dims)), TFL_AF_None), - (TFL_L2NormalizationOp $operand1, TFL_AF_None), - [(EqualOperands $operand1, $square_operand), - (L2NormValidReduceIndex $sq_op, $axis)]>; + (TFL_L2NormalizationOp $x, TFL_AF_None), + [(L2NormValidReduceIndex $sq_op, $axis)]>; // Below patterns for L2Normalize when there is an Add or Maximum // adding or clamping to a small constant scalar. def L2NormalizePattern2#FirstOp#SecondOp : Pat< - (FirstOp $operand1, + (FirstOp $x, (SecondOp (TFL_AddOp (TFL_SumOp - (TFL_SquareOp:$sq_op $square_operand), + (TFL_SquareOp:$sq_op $x), (ConstantOp I32ElementsAttr:$axis), $keep_dims), (ConstantOp $epsilon), TFL_AF_None)), TFL_AF_None), - (TFL_L2NormalizationOp $operand1, TFL_AF_None), - [(EqualOperands $operand1, $square_operand), - (L2NormValidReduceIndex $sq_op, $axis), + (TFL_L2NormalizationOp $x, TFL_AF_None), + [(L2NormValidReduceIndex $sq_op, $axis), (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>; def L2NormalizePattern3#FirstOp#SecondOp : Pat< - (FirstOp $operand1, + (FirstOp $x, (SecondOp (TFL_MaximumOp (TFL_SumOp - (TFL_SquareOp:$sq_op $square_operand), + (TFL_SquareOp:$sq_op $x), (ConstantOp I32ElementsAttr:$axis), $keep_dims), (ConstantOp $epsilon))), TFL_AF_None), - (TFL_L2NormalizationOp $operand1, TFL_AF_None), - [(EqualOperands $operand1, $square_operand), - (L2NormValidReduceIndex $sq_op, $axis), + (TFL_L2NormalizationOp $x, TFL_AF_None), + [(L2NormValidReduceIndex $sq_op, $axis), (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>; } @@ -481,9 +489,9 @@ def ConvertExpandDimsToReshape : Pat< [(AnyStaticShapeTensor $expand_dims_op)]>; class FloatValueEquals : Constraint().getNumElements() == 1 &&" - "$0.isa() &&" - "*$0.cast().getValues().begin() == " # val>>; + "$0.isa() && " + "llvm::all_of($0.cast().getFloatValues(), " + "[](const APFloat& f) { return f.isExactlyValue(" # val # "); })">>; // ReLU patterns def MatchReluPattern : Pat< @@ -505,12 +513,11 @@ def MatchRelu1Pattern2 : Pat< def MatchLeakyRelu : Pat< (TFL_MaximumOp - (TFL_MulOp:$mul_out $input1, + (TFL_MulOp:$mul_out $x, (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None), - $input2), - (TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha), + $x), + (TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha), [(ConstDoubleValueLessThan<"1"> $alpha), - (EqualOperands $input1, $input2), (HasOneUse $mul_out)]>; def RemoveTrivialCast : Pat<(TFL_CastOp:$output $input), @@ -526,15 +533,14 @@ def PReluAlphaRankCheck : Constraint< // f(x) = Relu(x) + (-alpha * Relu(-x)) def MatchPRelu : Pat< (TFL_AddOp - (TFL_ReluOp:$relu_out $input1), + (TFL_ReluOp:$relu_out $x), (TFL_MulOp:$mul_out - (TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)), + (TFL_ReluOp (TFL_NegOp:$input_neg_out $x)), $neg_alpha, TFL_AF_None), TFL_AF_None), - (TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)), - [(EqualOperands $input1, $input2), - (PReluAlphaRankCheck $neg_alpha, $input1), + (TFL_PReluOp $x, (TFL_NegOp $neg_alpha)), + [(PReluAlphaRankCheck $neg_alpha, $x), (HasOneUse $relu_out), (HasOneUse $mul_out), (HasOneUse $input_neg_out)]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index ecca3d38deb..c4f30c22be3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -740,31 +740,6 @@ struct ConvertTFBroadcastTo : public RewritePattern { } }; -struct ConvertFusedBatchNorm : public OpRewritePattern { - explicit ConvertFusedBatchNorm(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op, - PatternRewriter &rewriter) const override { - auto new_result_types = - llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes()); - // reserve_space_3 - new_result_types.push_back( - UnrankedTensorType::get(FloatType::getF32(rewriter.getContext()))); - - OperationState new_state(tf_fused_batch_norm_op.getLoc(), - TF::FusedBatchNormV3Op::getOperationName(), - tf_fused_batch_norm_op.getOperands(), - new_result_types, - tf_fused_batch_norm_op.getAttrs()); - Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state); - - rewriter.replaceOp(tf_fused_batch_norm_op, - tf_fused_batch_norm_op_v3->getResults().drop_back()); - return success(); - } -}; - // The below pattern is equivalent to the DRR rule below // The checks are dependent on generated values, so we can't add // the checks on intermediate values, ideally we should find equivalent @@ -1202,7 +1177,6 @@ void PrepareTFPass::runOnFunction() { patterns.insert, FusedBatchNormV3Pat, ConvertTFDilatedConvOp>(ctx); - patterns.insert(ctx); TFL::populateWithGenerated(ctx, patterns); // TODO(karimnosseir): Split to separate pass probably after // deciding on long term plan for this optimization. diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 1c740731acd..b09021e8689 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -841,6 +841,7 @@ cc_library( "transforms/executor_tpuv1_inline_tpu_island.cc", "transforms/executor_tpuv1_island_coarsening.cc", "transforms/executor_tpuv1_outline_tpu_island.cc", + "transforms/fold_broadcast.cc", "transforms/fold_switch.cc", "transforms/functional_control_flow_to_cfg.cc", "transforms/functional_control_flow_to_regions.cc", @@ -930,6 +931,7 @@ cc_library( ":shape_inference_utils", ":tensorflow", ":tensorflow_analysis", + ":tensorflow_ops", ":tensorflow_optimize_inc_gen", ":tensorflow_types", ":tf_data_optimization", @@ -957,6 +959,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Parser", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 362c0081c77..8631d22694e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1041,7 +1041,7 @@ beta function. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect, TF_ContractionFusableInterface]> { +def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect, TF_ContractionFusableInterface, TF_LayoutSensitiveInterface]> { let summary = "Adds `bias` to `value`."; let description = [{ @@ -1065,6 +1065,11 @@ Broadcasting is supported, so `value` may have any number of dimensions. let extraClassDeclaration = [{ // TF_ContractionFusableInterface: Optional GetContractionFusion(); + // TF_LayoutSensitiveInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); }]; let verifier = [{ @@ -2799,6 +2804,23 @@ Converts the given variant tensor to an iterator and stores it in the given reso let results = (outs); } +def TF_DestroyResourceOp : TF_Op<"DestroyResourceOp", []> { + let summary = "Deletes the resource specified by the handle."; + + let description = [{ +All subsequent operations using the resource will result in a NotFound +error status. + }]; + + let arguments = (ins + TF_ResourceTensor:$resource, + + DefaultValuedAttr:$ignore_lookup_error + ); + + let results = (outs); +} + def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> { let summary = "Return the index of device the op runs."; @@ -3925,6 +3947,8 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + let hasCanonicalizer = 1; + let verifier = [{ return Verify(*this); }]; @@ -13204,6 +13228,175 @@ num_elements: optional. If not -1, the number of elements in the list. }]; } +def TF_TensorScatterAddOp : TF_Op<"TensorScatterAdd", [NoSideEffect]> { + let summary = [{ +Adds sparse `updates` to an existing tensor according to `indices`. + }]; + + let description = [{ +This operation creates a new tensor by adding sparse `updates` to the passed +in `tensor`. +This operation is very similar to `tf.scatter_nd_add`, except that the updates +are added onto an existing tensor (as opposed to a variable). If the memory +for the existing tensor cannot be re-used, a copy is made and updated. + +`indices` is an integer tensor containing indices into a new tensor of shape +`tensor.shape`. The last dimension of `indices` can be at most the rank of +`tensor.shape`: + + indices.shape[-1] <= tensor.shape.rank + +The last dimension of `indices` corresponds to indices into elements +(if `indices.shape[-1] = tensor.shape.rank`) or slices +(if `indices.shape[-1] < tensor.shape.rank`) along dimension +`indices.shape[-1]` of `tensor.shape`. `updates` is a tensor with shape + + indices.shape[:-1] + tensor.shape[indices.shape[-1]:] + +The simplest form of tensor_scatter_add is to add individual elements to a +tensor by index. For example, say we want to add 4 elements in a rank-1 +tensor with 8 elements. + +In Python, this scatter add operation would look like this: + +```python + indices = tf.constant([[4], [3], [1], [7]]) + updates = tf.constant([9, 10, 11, 12]) + tensor = tf.ones([8], dtype=tf.int32) + updated = tf.tensor_scatter_nd_add(tensor, indices, updates) + print(updated) +``` + +The resulting tensor would look like this: + + [1, 12, 1, 11, 10, 1, 1, 13] + +We can also, insert entire slices of a higher rank tensor all at once. For +example, if we wanted to insert two slices in the first dimension of a +rank-3 tensor with two matrices of new values. + +In Python, this scatter add operation would look like this: + +```python + indices = tf.constant([[0], [2]]) + updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], + [7, 7, 7, 7], [8, 8, 8, 8]], + [[5, 5, 5, 5], [6, 6, 6, 6], + [7, 7, 7, 7], [8, 8, 8, 8]]]) + tensor = tf.ones([4, 4, 4],dtype=tf.int32) + updated = tf.tensor_scatter_nd_add(tensor, indices, updates) + print(updated) +``` + +The resulting tensor would look like this: + + [[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]], + [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], + [[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]], + [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]] + +Note that on CPU, if an out of bound index is found, an error is returned. +On GPU, if an out of bound index is found, the index is ignored. + }]; + + let arguments = (ins + TF_Tensor:$tensor, + TF_I32OrI64Tensor:$indices, + TF_Tensor:$updates + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_TensorScatterSubOp : TF_Op<"TensorScatterSub", [NoSideEffect]> { + let summary = [{ +Subtracts sparse `updates` from an existing tensor according to `indices`. + }]; + + let description = [{ +This operation creates a new tensor by subtracting sparse `updates` from the +passed in `tensor`. +This operation is very similar to `tf.scatter_nd_sub`, except that the updates +are subtracted from an existing tensor (as opposed to a variable). If the memory +for the existing tensor cannot be re-used, a copy is made and updated. + +`indices` is an integer tensor containing indices into a new tensor of shape +`shape`. The last dimension of `indices` can be at most the rank of `shape`: + + indices.shape[-1] <= shape.rank + +The last dimension of `indices` corresponds to indices into elements +(if `indices.shape[-1] = shape.rank`) or slices +(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of +`shape`. `updates` is a tensor with shape + + indices.shape[:-1] + shape[indices.shape[-1]:] + +The simplest form of tensor_scatter_sub is to subtract individual elements +from a tensor by index. For example, say we want to insert 4 scattered elements +in a rank-1 tensor with 8 elements. + +In Python, this scatter subtract operation would look like this: + +```python + indices = tf.constant([[4], [3], [1], [7]]) + updates = tf.constant([9, 10, 11, 12]) + tensor = tf.ones([8], dtype=tf.int32) + updated = tf.tensor_scatter_nd_sub(tensor, indices, updates) + print(updated) +``` + +The resulting tensor would look like this: + + [1, -10, 1, -9, -8, 1, 1, -11] + +We can also, insert entire slices of a higher rank tensor all at once. For +example, if we wanted to insert two slices in the first dimension of a +rank-3 tensor with two matrices of new values. + +In Python, this scatter add operation would look like this: + +```python + indices = tf.constant([[0], [2]]) + updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], + [7, 7, 7, 7], [8, 8, 8, 8]], + [[5, 5, 5, 5], [6, 6, 6, 6], + [7, 7, 7, 7], [8, 8, 8, 8]]]) + tensor = tf.ones([4, 4, 4],dtype=tf.int32) + updated = tf.tensor_scatter_nd_sub(tensor, indices, updates) + print(updated) +``` + +The resulting tensor would look like this: + + [[[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]], + [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], + [[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]], + [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]] + +Note that on CPU, if an out of bound index is found, an error is returned. +On GPU, if an out of bound index is found, the index is ignored. + }]; + + let arguments = (ins + TF_Tensor:$tensor, + TF_I32OrI64Tensor:$indices, + TF_Tensor:$updates + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_TensorScatterUpdateOp : TF_Op<"TensorScatterUpdate", [NoSideEffect]> { let summary = [{ Scatter `updates` into an existing tensor according to `indices`. @@ -13456,6 +13649,33 @@ The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: let hasFolder = 1; } +def TF_TridiagonalSolveOp : TF_Op<"TridiagonalSolve", [NoSideEffect]> { + let summary = "Solves tridiagonal systems of equations."; + + let description = [{ +Solves tridiagonal systems of equations. + Supports batch dimensions and multiple right-hand sides per each left-hand + side. + On CPU, solution is computed via Gaussian elimination with or without partial + pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE + library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv + Partial pivoting is not yet supported by XLA backends. + }]; + + let arguments = (ins + TensorOf<[TF_Complex128, TF_Complex64, TF_Float32, TF_Float64]>:$diagonals, + TensorOf<[TF_Complex128, TF_Complex64, TF_Float32, TF_Float64]>:$rhs, + + DefaultValuedAttr:$partial_pivoting + ); + + let results = (outs + TensorOf<[TF_Complex128, TF_Complex64, TF_Float32, TF_Float64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_TruncateDivOp : TF_Op<"TruncateDiv", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Returns x / y element-wise for integer types."; @@ -14401,6 +14621,27 @@ key: A unique identifier for this region used to match up host transfers. TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaSortOp : TF_Op<"XlaSort", [NoSideEffect]> { + let summary = "Wraps the XLA Sort operator, documented at"; + + let description = [{ +https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only sorts in ascending order are supported. + }]; + + let arguments = (ins + Arg:$input + ); + + let results = (outs + Res:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_XlaSvdOp : TF_Op<"XlaSvd", [NoSideEffect]> { let summary = [{ Computes the eigen decomposition of a batch of self-adjoint matrices @@ -14766,4 +15007,4 @@ execution the transfer corresponds to.}]>:$dynamic_key, let results = (outs); TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 634004038d0..2b3eab72226 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -90,7 +90,9 @@ bool HasSingleUse(FuncOp func) { // Inspect function uses in the containing module and all parent // modules. bool use_seen = false; - for (; module; module = module.getParentOfType()) { + for (; module; module = func.isPrivate() + ? nullptr + : module.getParentOfType()) { auto func_uses_optional = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); // Found an unknown use. @@ -105,15 +107,36 @@ bool HasSingleUse(FuncOp func) { // This is the first use seen. use_seen = true; - - // If the function is private, no need to inspect parent modules. - if (func.isPrivate()) break; } // No multiple uses seen. return true; } +// Returns true if the caller ops can be inlined. +bool HasInlinableUsers(FuncOp func) { + // Return false if unexpected IR structure seen. + ModuleOp module = func.getParentOfType(); + if (!module) return false; + + // Inspect function uses in the containing module and all parent + // modules. + for (; module; module = func.isPrivate() + ? nullptr + : module.getParentOfType()) { + auto func_uses_optional = + SymbolTable::getSymbolUses(func, &module.getBodyRegion()); + // Found an unknown use. + if (!func_uses_optional) return false; + + for (auto &use : func_uses_optional.getValue()) + if (isa(use.getUser())) return false; + } + + // All caller ops that can be inlined. + return true; +} + struct TFConstantFoldInterface : public DialectFoldInterface { TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {} LogicalResult fold(Operation *op, ArrayRef operands, @@ -153,11 +176,14 @@ struct TFInlinerInterface : public DialectInlinerInterface { BlockAndValueMapping &) const final { // An op is legal to inline if either of the following conditions is true: // (a) Its legal to duplicate the Op. - // (a) The Op is inside a single use function. If that function is inlined, + // (b) The Op is inside a single use function. If that function is inlined, // post inlining, the function will be dead and eliminated from the IR. // So there won't be any code duplication. + // plus the function caller op can be replaced by inlined ops. FuncOp func = op->getParentOfType(); - return !func || TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func); + if (!func) return true; + if (!HasInlinableUsers(func)) return false; + return TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func); } //===--------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 67ad0fc4e70..c814153eb43 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -178,6 +178,9 @@ An n-way switch statement, implementing the following: let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; + } // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with @@ -387,6 +390,12 @@ else_branch: A region that computes the outputs of the op if cond = false. return Verify(*this); }]; + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, TypeRange resultTypes, ValueRange operands, llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions", [{ + assert(numRegions == 2u && "mismatched number of regions"); + build(builder, result, resultTypes, operands, attributes); + }]>]; + let hasCanonicalizer = 1; } @@ -1626,7 +1635,7 @@ event: A string containing a binary-encoded tf.Event proto. let results = (outs); } -def TF_SummaryWriterOp : TF_Op<"SummaryWriter", []> { +def TF_SummaryWriterOp : TF_Op<"SummaryWriter", [TF_ResourceHandleAllocatorInterface]> { let summary = "Returns a handle to be used to access a summary writer."; let description = [{ @@ -1644,6 +1653,13 @@ writer: the summary writer resource. Scalar handle. let results = (outs Res:$writer ); + + let extraClassDeclaration = [{ + // TF_ResourceHandleAllocatorInterface: + ResourceHandleValueAndId GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id); + }]; } def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index df23638d186..9d523640d6f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -454,6 +454,20 @@ Optional BiasAddOp::GetContractionFusion() { return ContractionFusion("BiasAdd", /*additional_arguments=*/{1}); } +LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) { + return ::mlir::TF::UpdateDataFormat(data_format, this); +} + +StringRef BiasAddOp::GetOptimalLayout(const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // Prefer NHWC for GPU devices. + return "NHWC"; +} + //===----------------------------------------------------------------------===// // BiasAddGradOp //===----------------------------------------------------------------------===// @@ -659,6 +673,75 @@ static LogicalResult Verify(CaseRegionOp op) { return success(); } +namespace { +// Eliminate values that pass through the CaseRegionOp or IfRegionOp branches. +template +class CaseOrIfRegionEliminatePassThrough + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CaseOrIfRegionOp op, + PatternRewriter &rewriter) const override { + RegionRange branches = op.getRegions(); + SmallVector new_result_types; + // Maps pass through results to extern values. + llvm::SmallDenseMap result_to_extern_value; + + for (auto result : op.getResults()) { + unsigned index = result.getResultNumber(); + Region *first_branch = *branches.begin(); + Operation *first_terminator = first_branch->front().getTerminator(); + Value returned_val = first_terminator->getOperand(index); + + // Pass through values would be defined outside the branch region. Keep + // the type of non pass through results to create a new op later, if + // required. + if (returned_val.getParentBlock() == &first_branch->front()) { + new_result_types.push_back(result.getType()); + continue; + } + // Check if the same extern value is returned in each branch. + for (Region *region : branches.drop_front()) { + Operation *terminator = region->front().getTerminator(); + if (terminator->getOperand(index) != returned_val) return failure(); + } + result_to_extern_value[result] = returned_val; + } + + // If no pass through values are found, no change is required. + if (result_to_extern_value.empty()) return failure(); + + // Create new case/if region op. + auto new_op = rewriter.create( + op.getLoc(), new_result_types, op.getOperand(), op.getAttrs(), + op.getNumRegions()); + + int next_index = 0; + for (auto result : op.getResults()) { + if (!result_to_extern_value.count(result)) { + result.replaceAllUsesWith(new_op.getResult(next_index++)); + continue; + } + result.replaceAllUsesWith(result_to_extern_value[result]); + for (Region *branch : branches) + branch->front().getTerminator()->eraseOperand(next_index); + } + + // Move region bodies to the new op. + for (auto region_index : llvm::seq(0, branches.size())) + new_op.getRegion(region_index).takeBody(op.getRegion(region_index)); + + op.erase(); + return success(); + } +}; +} // namespace + +void CaseRegionOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert>(context); +} + //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// @@ -2112,7 +2195,8 @@ LogicalResult FoldConstantIfRegionOp::matchAndRewrite( void IfRegionOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert>(context); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc index 44df2b12d88..72ca50b5c37 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc @@ -587,3 +587,31 @@ struct DropAttributes : public OpRewritePattern { } }; +//===----------------------------------------------------------------------===// +// TF op helper functions for handling resource handles and ids. +//===----------------------------------------------------------------------===// + +// Returns device of op if present. If op has no device set, an empty string ref +// is returned instead. +llvm::StringRef GetDeviceOrEmpty(Operation *op) { + if (auto device_attr = op->getAttrOfType("device")) + return device_attr.getValue(); + return llvm::StringRef(); +} + +// Returns resource handle value and id for resource op based on attributes. If +// a resource handle is anonymous, a new id is always returned. +ResourceHandleValueAndId GetResourceHandleValueAndIdBase( + llvm::StringRef container, llvm::StringRef shared_name, + llvm::StringRef device, Value resource, + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id) { + // Always create a new ID for anonymous handle. + if (IsResourceHandleAnonymous(shared_name)) return {resource, next_id++}; + + ResourceHandle handle(container, shared_name, device, /*op=*/nullptr); + auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id); + // New ID created, increment next_id. + if (emplace_res.second) ++next_id; + return {resource, emplace_res.first->second}; +} diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 8742a0e2b71..b99c99029ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -28,6 +28,7 @@ limitations under the License. #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -35,6 +36,7 @@ limitations under the License. #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/Twine.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" @@ -580,131 +582,143 @@ Optional ReluOp::GetContractionFusion() { // ReshapeOp //===----------------------------------------------------------------------===// -// TODO(b/128020684): Verify the output type. -static LogicalResult Verify(ReshapeOp op) { - auto shape_type = op.shape().getType().cast(); - if (!shape_type.hasRank()) return success(); - if (shape_type.getRank() != 1) - return op.emitOpError("shape must be 1D tensor"); - auto rank_by_shape = shape_type.getShape()[0]; - auto type_of_tensor = op.tensor().getType().cast(); - // No compile time verification for unknown sized shape. - if (rank_by_shape == -1 || !type_of_tensor.hasStaticShape()) return success(); - int64_t num_by_tensor = type_of_tensor.getNumElements(); +namespace { +using ReshapeErrorHandler = + llvm::function_ref; - auto out_ty = op.getType().dyn_cast(); - if (out_ty && out_ty.hasStaticShape()) { - int64_t num_output_elements = out_ty.getNumElements(); - if (num_by_tensor != num_output_elements) - return op.emitOpError() - << "number of output elements (" << num_output_elements - << ") does not match expected number of elements (" - << num_by_tensor << ")"; - } +LogicalResult GetReshapeOutputType(Value tensor, Value shape, + ReshapeErrorHandler error_handler, + TensorType &output_ty) { + auto tensor_ty = tensor.getType().cast(); + auto element_ty = tensor_ty.getElementType(); + output_ty = UnrankedTensorType::get(element_ty); - // Check values if constant shape. No compiling time verification for - // non-constant shape. - auto *shape_op = op.shape().getDefiningOp(); - if (!shape_op) return success(); - Attribute shape_cst; - if (!matchPattern(shape_op, m_Constant(&shape_cst))) return success(); - auto shape_cst_attr = shape_cst.dyn_cast(); - if (!shape_cst_attr) return op.emitOpError("shape must be a valid tensor"); + auto shape_ty = shape.getType().dyn_cast(); + if (!shape_ty) return success(); + if (shape_ty.getRank() != 1) + return error_handler(llvm::formatv( + "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank())); - if (auto opaque_attr = shape_cst_attr.dyn_cast()) { - opaque_attr.decode(shape_cst_attr); - } - - // We know the shape is a 1-D Tensor, then let us get the number of - // elements it implies. - unsigned num_by_shape = 1; - unsigned unknown_dim_count = 0; - for (int i = 0, e = rank_by_shape; i != e; ++i) { - auto num = shape_cst_attr.getValue(i).getInt(); - // The dimension size value can be -1, and that the real size needs to - // be computed so that the total size remains constant. At most one - // component of shape can be -1. - if (num == -1) { - if (++unknown_dim_count > 1) { - return op.emitOpError("more than one component of shape are -1"); - } - } else { - num_by_shape *= num; + DenseIntElementsAttr shape_attr; + if (!matchPattern(shape, m_Constant(&shape_attr))) { + // If only shape of `shape` is known, return ranked but dynamic output + // shape. + if (shape_ty.hasStaticShape()) { + llvm::SmallVector dynamic_shape(shape_ty.getDimSize(0), + ShapedType::kDynamicSize); + output_ty = RankedTensorType::get(dynamic_shape, element_ty); } - } - // If there is one component of shape is -1, the dimension should be - // computed so that the total size remains constant. - if (unknown_dim_count == 1) { - if (num_by_tensor % num_by_shape != 0) - return op.emitOpError( - "one component of shape is -1 but couldn't infer the dimension"); return success(); } - // If the elements by the tensor and implies by the shape don't match, - // fail this static check. - if (num_by_tensor != num_by_shape) { - return op.emitOpError( - "mismatch in tensor elements and shape implied elements"); + + // Detect if reshape output shape is folded. + bool shape_ty_zero_dim = false; + int unknown_index = -1; + // The product of constant shape argument excluding unknown dimension. + int64_t shape_ty_size = 1; + llvm::SmallVector output_ty_shape; + output_ty_shape.reserve(shape_attr.getNumElements()); + for (const auto &dim : llvm::enumerate(shape_attr.getIntValues())) { + const int64_t size = dim.value().getSExtValue(); + if (size == ShapedType::kDynamicSize) { + if (unknown_index != -1) + return error_handler(llvm::formatv( + "requires 'shape' to have at most one dynamic dimension, but got " + "multiple dynamic dimensions at indices {0} and {1}", + unknown_index, dim.index())); + + unknown_index = dim.index(); + } else if (size == 0) { + shape_ty_zero_dim = true; + } else if (size > 0) { + shape_ty_size *= size; + } else { + return error_handler( + llvm::formatv("requires 'shape' to have dimensions greater than -1, " + "but got {0} at index {1}", + size, dim.index())); + } + output_ty_shape.push_back(size); } + + if (!tensor_ty.hasStaticShape()) { + output_ty = RankedTensorType::get(output_ty_shape, element_ty); + return success(); + } + + // Compute the value of the unknown dimension. + if (unknown_index != -1) { + // Compute number of elements in tensor shape. + int64_t tensor_ty_size = 1; + bool tensor_ty_zero_dim = false; + for (const auto &dim : tensor_ty.getShape()) { + if (dim > 0 || !shape_ty_zero_dim) { + tensor_ty_size *= dim; + } else { + tensor_ty_zero_dim = true; + } + } + + const int64_t missing_dim = tensor_ty_size / shape_ty_size; + if (!tensor_ty_zero_dim && shape_ty_size * missing_dim != tensor_ty_size) + return error_handler( + llvm::formatv("requires 'tensor' number of elements be a multiple of " + "{0}, but got {1}", + shape_ty_size, tensor_ty_size)); + + // Set the unknown dimension such that total number of elements remain + // constant. + output_ty_shape[unknown_index] = missing_dim; + } + + output_ty = RankedTensorType::get(output_ty_shape, element_ty); + + return success(); +} +} // namespace + +static LogicalResult Verify(ReshapeOp op) { + auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult { + return op.emitOpError() << message; + }; + TensorType expected_ty; + if (failed(GetReshapeOutputType(op.tensor(), op.shape(), error_handler, + expected_ty))) + return failure(); + + auto output_ty = op.getType().dyn_cast(); + if (!output_ty) return success(); + auto tensor_ty = op.tensor().getType().cast(); + if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) { + const int64_t output_ty_size = output_ty.getNumElements(); + const int64_t tensor_ty_size = tensor_ty.getNumElements(); + if (tensor_ty_size != output_ty_size) + return op.emitOpError() << "requires 'output' number of elements to " + "match 'tensor' number of elements, but got " + << output_ty_size << " and " << tensor_ty_size; + } + + if (!AreCastCompatible({output_ty, expected_ty})) + return op.emitOpError() + << "requires 'output' type " << output_ty + << " to be cast compatible with expected type " << expected_ty; + return success(); } +// Currently there are use cases that rely on partial evaluation of the `shape` +// operand, so InferTypeOpInterface is not used (along with generated builder of +// the same signature). void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, Value shape) { - auto ttype = tensor.getType().cast(); - auto etype = ttype.getElementType(); - - auto unranked = [&builder, etype, &result, shape, tensor]() { - return ReshapeOp::build(builder, result, UnrankedTensorType::get(etype), - tensor, shape); + auto error_handler = [&result](const llvm::Twine &message) { + return mlir::emitError(result.location) << message; }; + TensorType output_ty; + if (failed(GetReshapeOutputType(tensor, shape, error_handler, output_ty))) + return; - // If tensor is unranked then we have no info about output of shape. - if (!ttype.hasRank()) return unranked(); - - DenseIntElementsAttr attr_shape; - if (matchPattern(shape, m_Constant(&attr_shape))) { - llvm::SmallVector const_shape; - const_shape.reserve(attr_shape.getNumElements()); - - // Detect if reshape output shape is folded. - bool flatten = false; - int unknown_index = -1; - // The product of constant shape argument excluding unknown dimension. - int64_t product_cshape = 1; - for (auto e : llvm::enumerate(attr_shape)) { - int64_t val = e.value().getSExtValue(); - if (IsUnknownDimOrRank(val)) { - if (flatten) { - mlir::emitError(result.location) - << "only one unknown dimension allowed"; - return; - } - flatten = true; - unknown_index = e.index(); - } else { - product_cshape *= val; - } - const_shape.push_back(val); - } - - // Compute the value of the unknown dimension. - if (flatten) { - // Compute number of elements in tensor shape. - auto tshape = ttype.getShape(); - int64_t product_tshape = std::accumulate(tshape.begin(), tshape.end(), 1, - std::multiplies()); - // Set the unknown dimension such that total number of elements remain - // constant. - // Note: The case where the ratio is not integral, and so the total size - // of reshape not constant, is checked in verify function. - const_shape[unknown_index] = product_tshape / product_cshape; - } - return ReshapeOp::build(builder, result, - RankedTensorType::get(const_shape, etype), tensor, - shape); - } - return unranked(); + return ReshapeOp::build(builder, result, output_ty, tensor, shape); } void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, @@ -1915,6 +1929,19 @@ bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( return true; } +//===----------------------------------------------------------------------===// +// SummaryWriterOp +//===----------------------------------------------------------------------===// + +ResourceHandleValueAndId SummaryWriterOp::GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id) { + llvm::StringRef device = GetDeviceOrEmpty(getOperation()); + return GetResourceHandleValueAndIdBase(container(), shared_name(), device, + writer(), resource_handle_id_map, + next_id); +} + //===----------------------------------------------------------------------===// // TensorListReserveOp //===----------------------------------------------------------------------===// @@ -2322,6 +2349,41 @@ void NonMaxSuppressionV3Op::getCanonicalizationPatterns( results.insert(context); } +//===----------------------------------------------------------------------===// +// FusedBatchNormOp +//===----------------------------------------------------------------------===// + +namespace { + +class ConvertFusedBatchNorm : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op, + PatternRewriter &rewriter) const override { + auto new_result_types = + llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes()); + // reserve_space_3 + new_result_types.push_back( + UnrankedTensorType::get(FloatType::getF32(rewriter.getContext()))); + + OperationState new_state(tf_fused_batch_norm_op.getLoc(), + TF::FusedBatchNormV3Op::getOperationName(), + tf_fused_batch_norm_op.getOperands(), + new_result_types, + tf_fused_batch_norm_op.getAttrs()); + Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state); + + rewriter.replaceOp(tf_fused_batch_norm_op, + tf_fused_batch_norm_op_v3->getResults().drop_back()); + return success(); + } +}; +} // namespace. + +void FusedBatchNormOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // UnpackOp //===----------------------------------------------------------------------===// @@ -2396,18 +2458,10 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) { ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId( llvm::SmallDenseMap &resource_handle_id_map, int64_t &next_id) { - // Always create a new ID for anonymous handle. - if (IsResourceHandleAnonymous(shared_name())) return {resource(), next_id++}; - - llvm::StringRef device; - if (auto device_attr = getAttrOfType("device")) - device = device_attr.getValue(); - - ResourceHandle handle(container(), shared_name(), device, /*op=*/nullptr); - auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id); - // New ID created, increment next_id. - if (emplace_res.second) ++next_id; - return {resource(), emplace_res.first->second}; + llvm::StringRef device = GetDeviceOrEmpty(getOperation()); + return GetResourceHandleValueAndIdBase(container(), shared_name(), device, + resource(), resource_handle_id_map, + next_id); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index b0d96963088..e77dd365abf 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -975,6 +975,65 @@ func @foldIfRegionMismatchedTypes(%arg0: tensor, %arg1: tensor, %a return %0 : tensor<1xf32> } +// CHECK-LABEL: func @eliminatePassThroughIfRegion( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor +func @eliminatePassThroughIfRegion(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { + // CHECK: %[[PRED:.*]] = "tf._SomeOp"() : () -> tensor + %pred = "tf._SomeOp"() : () -> tensor + // CHECK: %[[IF_OUTPUT:.*]] = "tf.IfRegion"(%[[PRED]]) ( { + // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[ARG0]], %[[ARG1]]) + // CHECK: "tf.Yield"(%[[MUL]]) : (tensor) + // CHECK: }, { + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG0]], %[[ARG1]]) + // CHECK: "tf.Yield"(%[[SUB]]) : (tensor) + // CHECK: }) {is_stateless = true} : (tensor) -> tensor + %0:4 = "tf.IfRegion"(%pred) ({ + %true_value = "tf.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%arg1, %arg2, %true_value, %arg2) : (tensor, tensor, tensor, tensor) -> () + }, { + %false_value = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%arg1, %arg2, %false_value, %arg2) : (tensor, tensor, tensor, tensor) -> () + }) { is_stateless = true}: (tensor) -> (tensor, tensor, tensor, tensor) + // CHECK: "tf._SomeOp"(%[[ARG2]], %[[ARG1]]) : (tensor, tensor) -> () + "tf._SomeOp"(%0#1, %0#0) : (tensor, tensor) -> () + // CHECK: "tf._SomeOp"(%[[ARG2]], %[[IF_OUTPUT]]) : (tensor, tensor) -> () + "tf._SomeOp"(%0#3, %0#2) : (tensor, tensor) -> () + // CHECK: return %[[IF_OUTPUT]] : tensor + return %0#2 : tensor +} + +// CHECK-LABEL: func @eliminatePassThroughCaseRegion( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor +func @eliminatePassThroughCaseRegion(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { + // CHECK: %[[INDEX:.*]] = "tf._SomeOp"() : () -> tensor + %index = "tf._SomeOp"() : () -> tensor + // CHECK: %[[CASE_OUTPUT:.*]] = "tf.CaseRegion"(%[[INDEX]]) ( { + // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[ARG0]], %[[ARG1]]) + // CHECK: "tf.Yield"(%[[MUL]]) : (tensor) + // CHECK: }, { + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG0]], %[[ARG1]]) + // CHECK: "tf.Yield"(%[[SUB]]) : (tensor) + // CHECK: }, { + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[ARG0]], %[[ARG1]]) + // CHECK: "tf.Yield"(%[[ADD]]) : (tensor) + // CHECK: }) {is_stateless = true} : (tensor) -> tensor + %0:3 = "tf.CaseRegion"(%index) ({ + %mul = "tf.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%arg1, %mul, %arg2) : (tensor, tensor, tensor) -> () + }, { + %sub = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%arg1, %sub, %arg2) : (tensor, tensor, tensor) -> () + }, { + %add = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%arg1, %add, %arg2) : (tensor, tensor, tensor) -> () + }) { is_stateless = true}: (tensor) -> (tensor, tensor, tensor) + // CHECK: "tf._SomeOp"(%[[ARG2]], %[[ARG1]]) : (tensor, tensor) -> () + "tf._SomeOp"(%0#2, %0#0) : (tensor, tensor) -> () + // CHECK: return %[[CASE_OUTPUT]] : tensor + return %0#1 : tensor +} + + // CHECK-LABEL: foldCase func @foldCase(%arg0: tensor, %arg1: tensor) -> (tensor) { %2 = constant dense<1> : tensor @@ -1225,3 +1284,10 @@ func @testNMSV3ToNMSV4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tens %0 = "tf.NonMaxSuppressionV3"(%arg0, %arg1, %max_size, %arg2, %arg3): (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>) return %0 : tensor<2xi32> } + +// CHECK-LABEL: testFusedBatchNormToBatchNormV3 +func @testFusedBatchNormToBatchNormV3(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "tf.FusedBatchNormV3" + %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4): (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> ) + return %0#0 : tensor<8x8x8x8xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt index 5fb90b1bce0..b8c779992ac 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/graph-resource.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-control-output-arrays=assign_variable | tf-mlir-translate -mlir-tf-graph-to-hlo-text -tf-input-shapes=2:2 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-xla-input-types=parameter,resource -emit-return-tuple | FileCheck %s node { name: "arg0" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir new file mode 100644 index 00000000000..afc9e1e51ed --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir @@ -0,0 +1,43 @@ +// RUN: tf-opt -tf-broadcast-fold %s | FileCheck %s + +// CHECK-LABEL: @broadcast_mul0 +func @broadcast_mul0(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + return %1 : tensor<5x7xf32> + // CHECK: %[[V0:.*]] = "tf.Mul"(%arg0, %arg1) : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> + // CHECK: %[[V0]] : tensor<5x7xf32> +} + +// CHECK-LABEL: @broadcast_mul1 +func @broadcast_mul1(%arg0: tensor<7xf32>, %arg1: tensor<5x7xf32>) -> tensor<5x7xf32> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.Mul"(%0, %arg1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + return %1 : tensor<5x7xf32> + // CHECK: %[[V0:.*]] = "tf.Mul"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + // CHECK: %[[V0]] : tensor<5x7xf32> +} + +// CHECK-LABEL: @broadcast_add_implicit_fold +func @broadcast_add_implicit_fold(%arg0: tensor<5x1xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.AddV2"(%arg0, %0) : (tensor<5x1xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + return %1 : tensor<5x7xf32> + // CHECK: %[[V0:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<5x1xf32>, tensor<7xf32>) -> tensor<5x7xf32> + // CHECK: %[[V0]] : tensor<5x7xf32> +} + +// CHECK-LABEL: @broadcast_mul_implicit_no_fold +func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32>) -> tensor<3x5x7xf32> { + %cst = constant dense<[3, 5, 7]> : tensor<3xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5xf32>, tensor<3xi32>) -> tensor<3x5x7xf32> + %1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32> + return %1 : tensor<3x5x7xf32> + // CHECK: %[[C0:.*]] = constant dense<[3, 5, 7]> : tensor<3xi32> + // CHECK: %[[V0:.*]] = "tf.BroadcastTo"(%arg1, %[[C0]]) : (tensor<5xf32>, tensor<3xi32>) -> tensor<3x5x7xf32> + // CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32> + // CHECK: %[[V1]] : tensor<3x5x7xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir index c52488b4afc..1f0a183c19e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tf-executor-graph-pruning | FileCheck %s +// RUN: tf-opt %s -split-input-file -tf-executor-graph-pruning | FileCheck %s // Two islands chained by data-flow contributing to the graph return are // preserved. @@ -18,20 +18,6 @@ func @chained_islands(%arg0 : i32) -> i32 { return %0 : i32 } -// Check that a function that does not have arguments/results is ignored by -// thep pruning pass: this could be a V1 graph imported without feeds/fetches. -// CHECK-LABEL: func @empty_islands( -func @empty_islands() { -// CHECK: tf_executor.island - tf_executor.graph { - %0 = tf_executor.island { - tf_executor.yield - } - tf_executor.fetch - } - return -} - // Check that an unused island that doesn't contribute to the fetch is removed. // CHECK-LABEL: func @dead_island( func @dead_island(%arg0 : i32) -> i32 { @@ -165,3 +151,37 @@ func @control_fetch(%arg0 : i32) { } return } + +// ----- + +// Check that a function that is named "main" and does not have the +// "tf.entry_function" attribute defined is ignored by the pruning pass: this +// could be a V1 graph imported without feed/fetch/target nodes. +// CHECK-LABEL: func @main( +func @main() { +// CHECK: tf_executor.island + tf_executor.graph { + %0 = tf_executor.island { + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// ----- + +// Check that a function that is named "main" and does have the +// "tf.entry_function" attribute defined with no feed/fetch/target nodes is +// pruned. +// CHECK-LABEL: func @main( +func @main() attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = ""}} { +// CHECK-NOT: tf_executor.island + tf_executor.graph { + %0 = tf_executor.island { + tf_executor.yield + } + tf_executor.fetch + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir index 7e583d0425a..67b4691f296 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir @@ -15,6 +15,23 @@ func @inline_simple() -> tensor<2xi32> { return %result : tensor<2xi32> } +// Test that TPUParitionedCallOp is not inlined. + +func @simple_callee() -> tensor<2xi32> attributes {sym_visibility = "private"} { + %cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32> + return %cst : tensor<2xi32> +} + +// CHECK-LABEL: func @dont_inline_tpu_partitioned_call( +func @dont_inline_tpu_partitioned_call() -> tensor<2xi32> { + // CHECK-NEXT: %[[ORDINAL:.*]] = "tf.TPUOrdinalSelector" + // CHECK-NEXT: %[[PARTITIONED_CALL:.*]] = "tf.TPUPartitionedCall"(%[[ORDINAL]]) + // CHECK-NEXT: return %[[PARTITIONED_CALL]] + %0 = "tf.TPUOrdinalSelector"() {device = ""} : () -> tensor + %result = "tf.TPUPartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @simple_callee} : (tensor) -> tensor<2xi32> + return %result : tensor<2xi32> +} + // Check that TF call operations can be inlined, even when the shape of the // argument or result is different than the called function. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir index c71d8ef2850..0034d3f4308 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir @@ -70,3 +70,15 @@ func @transposeFusedBatchNormV3( return %y : tensor<1x64x28x28xf32> } + +// CHECK-LABEL: bias_add_nchw +func @bias_add_nchw(%arg0: tensor<1x256x150x150xf32>, %arg1: tensor<256xf32>) -> tensor<1x256x150x150xf32> { + // CHECK: (%[[ARG0:.*]]: tensor<1x256x150x150xf32>, %[[ARG1:.*]]: tensor<256xf32>) + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} + // CHECK: %[[R0:.*]] = "tf.Transpose"(%[[ARG0]], %[[CST]]) + // CHECK: %[[R1:.*]] = "tf.BiasAdd"(%[[R0]], %[[ARG1]]) {data_format = "NHWC", device = ""} + // CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} + // CHECK: "tf.Transpose"(%[[R1]], %[[CST_0]]) + %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW", device = ""} : (tensor<1x256x150x150xf32>, tensor<256xf32>) -> tensor<1x256x150x150xf32> + return %0 : tensor<1x256x150x150xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index fb4b24037a4..fcd2f2512fd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -515,9 +515,11 @@ func @addN_variant(%arg0: tensor>>, %arg1: tensor) -> tensor<2x2xf32> { - // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) - // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor) -> tensor<2x2xf32> + // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) + // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]]) + // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]]) + // CHECK: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<1x2xf32>, tensor) -> tensor<2x2xf32> // CHECK: return %[[RESULT]] %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -533,7 +535,12 @@ func @DynamicStitch_scalar_matrix_indices(%arg0: tensor<2xf32>, %arg1: tensor<2x // CHECK-DAG: %[[INP1:.*]] = "tf.Reshape"(%arg1, %[[SHAPE]]) : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32> // CHECK-DAG: %[[ITEMS1:.*]]:4 = "tf.Unpack"(%[[INP1]]) {axis = 0 : i64} : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK-DAG: %6 = "tf.ConcatV2"(%[[ITEMS1]]#3, %[[ITEMS1]]#2, %[[ITEMS1]]#1, %[[ITEMS1]]#0, %[[ITEMS0]], %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor) -> tensor<5x2xf32> + // CHECK-DAG: %[[ITEMS1_3:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#3, %[[AXIS]]) + // CHECK-DAG: %[[ITEMS1_2:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#2, %[[AXIS]]) + // CHECK-DAG: %[[ITEMS1_1:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#1, %[[AXIS]]) + // CHECK-DAG: %[[ITEMS1_0:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#0, %[[AXIS]]) + // CHECK-DAG: %[[ITEMS0_0:.*]] = "tf.ExpandDims"(%[[ITEMS0]], %[[AXIS]]) + // CHECK-DAG: "tf.ConcatV2"(%[[ITEMS1_3]], %[[ITEMS1_2]], %[[ITEMS1_1]], %[[ITEMS1_0]], %[[ITEMS0_0]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor) -> tensor<5x2xf32> %indices0 = "tf.Const"() {value = dense<4> : tensor} : () -> tensor %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> @@ -555,7 +562,9 @@ func @DynamicStitch_uint8(%arg0: tensor<2x2xui8>) -> tensor<2x2xui8> { func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-DAG: %[[ITEMS]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2xf32>) -> (tensor, tensor) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK-DAG: %[[RESULT]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor, tensor, tensor) -> tensor<2xf32> + // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]]) + // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]]) + // CHECK-DAG: %[[RESULT]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1xf32>, tensor<1xf32>, tensor) -> tensor<2xf32> // CHECK: return %[[RESULT]] %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -567,7 +576,9 @@ func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @DynamicStitch_matrix_item(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> { // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor) -> tensor<2x2x2xf32> + // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]]) + // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]]) + // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1x2x2xf32>, tensor<1x2x2xf32>, tensor) -> tensor<2x2x2xf32> // CHECK: return %[[RESULT]] %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -586,7 +597,8 @@ func @DynamicStitch_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tenso func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> { // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[AXIS]]) : (tensor<2xf32>, tensor) -> tensor<1x2xf32> + // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]]) + // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[AXIS]]) : (tensor<1x2xf32>, tensor) -> tensor<1x2xf32> // CHECK: return %[[RESULT]] %indices = "tf.Const"() {value = dense<[0, 0]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -629,7 +641,7 @@ func @Reciprocal_complexf64(%arg0: tensor<*xcomplex>) -> tensor<*xcomplex, %arg1: tensor<4xf32>) -> tensor<8xf32> { // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32> - // CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32> + // CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32> %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32> %0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32> @@ -693,3 +705,14 @@ func @round_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "tf.Round"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +// CHECK-LABEL: func @lgamma +func @lgamma(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // The lowering for lgamma is complicated, which makes it awkward to write a + // complete test for it here. Instead we test that Lgamma is at least being + // lowered here and rely on UnaryOpsTest.testFloatOps and other TensorFlow + // tests to check it is lowered correctly and with sufficient precision. + // CHECK-NOT: tf.Lgamma + %0 = "tf.Lgamma"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir index 52dc06cd393..03cac7dbd33 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir @@ -1,13 +1,12 @@ // RUN: tf-opt %s -tf-parallel-execute-to-islands | FILECHECK_OPTS="" FileCheck %s -// CHECK-LABEL: func @check_regions_to_islands -func @check_regions_to_islands() { +// CHECK-LABEL: func @testEmptyRegions +func @testEmptyRegions() { tf_executor.graph { tf_executor.island() { "tf_device.parallel_execute"() ({ tf_device.return - }, - { + }, { tf_device.return }) {} : () -> () tf_executor.yield @@ -17,210 +16,133 @@ func @check_regions_to_islands() { return } -// CHECK: %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK: [[ISLAND_0_CTRL:%.+]] = tf_executor.island { // CHECK: tf_executor.yield -// CHECK: %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island { +// CHECK: [[ISLAND_1_CTRL:%.+]] = tf_executor.island { // CHECK: tf_executor.yield -// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) { -// CHECK-NEXT: tf_executor.yield +// CHECK: tf_executor.fetch [[ISLAND_0_CTRL]], [[ISLAND_1_CTRL]] : -// CHECK-LABEL: func @check_regions_to_islands_with_inputs -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @check_regions_to_islands_with_inputs(%arg0 : tensor) { - tf_executor.graph { +// CHECK-LABEL: func @testDataOperandsAndResults +// CHECK-SAME: ([[ARG_0:%.+]]: tensor) +func @testDataOperandsAndResults(%arg0 : tensor) { + %0:2 = tf_executor.graph { %1:2 = tf_executor.island { %2 = "tf.opA"(%arg0) : (tensor) -> tensor tf_executor.yield %2 : tensor } - tf_executor.island() { - "tf_device.parallel_execute"() ({ - %3 = "tf.opB"(%1#0) : (tensor) -> tensor - tf_device.return %3 : tensor - }, - { + %3:3 = tf_executor.island() { + %4:2 = "tf_device.parallel_execute"() ({ + %5 = "tf.opB"(%1#0) : (tensor) -> tensor + tf_device.return %5 : tensor + }, { %5 = "tf.opC"(%1#0) : (tensor) -> tensor tf_device.return %5 : tensor }) {} : () -> (tensor, tensor) - tf_executor.yield + tf_executor.yield %4#0, %4#1 : tensor, tensor } - tf_executor.fetch + tf_executor.fetch %3#0, %3#1 : tensor, tensor } return } -// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor -// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor -// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor -// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor) -> tensor -// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor -// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor) -> tensor -// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor -// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) { -// CHECK-NEXT: tf_executor.yield +// CHECK: [[INPUT_A:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"([[ARG_0]]) +// CHECK-NEXT: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"([[INPUT_A]]) +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: [[ISLAND_1_OUTPUT:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_C_OUTPUT:%.+]] = "tf.opC"([[INPUT_A]]) +// CHECK: tf_executor.yield [[OP_C_OUTPUT]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_OUTPUT]] : -// CHECK-LABEL: func @check_input_sink_island_forwards_control_inputs -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @check_input_sink_island_forwards_control_inputs(%arg0 : tensor) { - tf_executor.graph { - %1:2 = tf_executor.island { - %2 = "tf.opA"(%arg0) : (tensor) -> tensor - tf_executor.yield %2 : tensor - } - %7 = tf_executor.ControlTrigger {} - %8 = tf_executor.ControlTrigger {} - tf_executor.island(%7, %8) { - "tf_device.parallel_execute"() ({ - %3 = "tf.opB"(%1#0) : (tensor) -> tensor - tf_device.return %3 : tensor - }, - { - %5 = "tf.opC"() : () -> tensor - tf_device.return %5 : tensor - }) {} : () -> (tensor, tensor) +// CHECK-LABEL: func @testControlOperands +func @testControlOperands() { + %0:2 = tf_executor.graph { + %1 = tf_executor.island { tf_executor.yield } - tf_executor.fetch - } - return -} - -// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor -// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor -// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) { -// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor -// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor) -> tensor -// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor -// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[INPUT_CONTROL]]) { -// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"() : () -> tensor -// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor -// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) { -// CHECK-NEXT: tf_executor.yield - - -// CHECK-LABEL: func @check_control_dep_added_when_region_does_not_have_inputs -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @check_control_dep_added_when_region_does_not_have_inputs(%arg0 : tensor) { - tf_executor.graph { - %1:2 = tf_executor.island { - %2 = "tf.opA"(%arg0) : (tensor) -> tensor - tf_executor.yield %2 : tensor - } - %7:3 = tf_executor.island() { - %8:2 = "tf_device.parallel_execute"() ( - { - %3 = "tf.opB"() : () -> tensor - tf_device.return %3 : tensor - }, - { - %5 = "tf.opC"(%1#0) : (tensor) -> tensor - tf_device.return %5 : tensor - } - ) {} : () -> (tensor, tensor) - - tf_executor.yield %8#0, %8#1 : tensor, tensor - } - - tf_executor.island { - "tf.opD"(%7#0, %7#1) : (tensor, tensor) -> () - tf_executor.yield - } - tf_executor.fetch - } - return -} - -// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor -// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor -// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor -// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) { -// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor -// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor -// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor) -> tensor -// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor -// CHECK: %{{.*}} = tf_executor.island { -// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] - - -// CHECK-LABEL: func @check_output_barrier_correctly_forwards_outputs -func @check_output_barrier_correctly_forwards_outputs(%arg0 : tensor) -> tensor { - %0 = tf_executor.graph { - %1:2 = tf_executor.island { - %2 = "tf.opA"(%arg0) : (tensor) -> tensor - tf_executor.yield %2 : tensor - } - %8:3 = tf_executor.island() { - %7:2 = "tf_device.parallel_execute"() ({ - %3 = "tf.opB"() : () -> tensor - tf_device.return %3 : tensor - }, - { - %5 = "tf.opC"(%1#0) : (tensor) -> tensor - tf_device.return %5 : tensor - }) {} : () -> (tensor, tensor) - tf_executor.yield %7#0, %7#1 : tensor, tensor - } - tf_executor.fetch %8#0 : tensor - } - return %0 : tensor -} - -// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor) -> tensor -// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor -// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor -// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) { -// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor -// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor -// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island { -// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%[[INPUT_0]]) : (tensor) -> tensor -// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor -// CHECK: %[[OUTPUT_SINK_OUTPUT:[a-z_0-9]*]]:2, %[[OUTPUT_SINK_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] : tensor, tensor - -// CHECK-LABEL: func @check_parallel_execute_using_args -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) -func @check_parallel_execute_using_args(%arg0 : tensor) { - tf_executor.graph { - %1:2 = tf_executor.island { - %2 = "tf.opA"(%arg0) : (tensor) -> tensor - tf_executor.yield %2 : tensor - } - %2:2 = tf_executor.island { - %3 = "tf.opB"(%arg0) : (tensor) -> tensor - tf_executor.yield %3 : tensor - } - tf_executor.island() { - "tf_device.parallel_execute"() ({ - %4 = "tf.opC"(%arg0, %1#0) : (tensor, tensor) -> tensor + %2:3 = tf_executor.island(%1) { + %3:2 = "tf_device.parallel_execute"() ({ + %4 = "tf.opA"() : () -> tensor tf_device.return %4 : tensor - }, - { - %5 = "tf.opD"(%arg0, %2#0) : (tensor, tensor) -> tensor - tf_device.return %5 : tensor + }, { + %4 = "tf.opB"() : () -> tensor + tf_device.return %4 : tensor }) {} : () -> (tensor, tensor) - tf_executor.yield + tf_executor.yield %3#0, %3#1 : tensor, tensor } - tf_executor.fetch + tf_executor.fetch %2#0, %2#1 : tensor, tensor } return } -// Verify that args are directly accessed in newly created island without alias -// through entry barrier. +// CHECK: [[INPUT_CTRL:%.+]] = tf_executor.island { +// CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island([[INPUT_CTRL]]) { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: [[ISLAND_1_OUTPUT:%.+]], {{%.+}} = tf_executor.island([[INPUT_CTRL]]) { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_OUTPUT]] : -// CHECK: "tf.opC"(%[[ARG_0]] -// CHECK: "tf.opD"(%[[ARG_0]] + +// CHECK-LABEL: func @testControlResults +func @testControlResults() { + tf_executor.graph { + %0:3 = tf_executor.island { + %1:2 = "tf_device.parallel_execute"() ({ + %2 = "tf.opA"() : () -> tensor + tf_device.return %2 : tensor + }, { + %2 = "tf.opB"() : () -> tensor + tf_device.return %2 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield %1#0, %1#1 : tensor, tensor + } + %3 = tf_executor.island(%0#2) { + tf_executor.yield + } + tf_executor.fetch %3 : !tf_executor.control + } + return +} + +// CHECK: {{%.+}}, [[ISLAND_0_CTRL:%.+]] = tf_executor.island { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: {{%.+}}, [[ISLAND_1_CTRL:%.+]] = tf_executor.island { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: [[OUTPUT_CTRL:%.+]] = tf_executor.island([[ISLAND_0_CTRL]], [[ISLAND_1_CTRL]]) { +// CHECK: [[FETCH_ISLAND:%.+]] = tf_executor.island([[OUTPUT_CTRL]]) { +// CHECK: tf_executor.fetch [[FETCH_ISLAND]] : !tf_executor.control + + +// CHECK-LABEL: func @testSomeRegionNoUsers +func @testSomeRegionNoUsers() { + %0 = tf_executor.graph { + %1:3 = tf_executor.island { + %2:2 = "tf_device.parallel_execute"() ({ + %3 = "tf.opA"() : () -> tensor + tf_device.return %3 : tensor + }, { + %3 = "tf.opB"() : () -> tensor + tf_device.return %3 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield %2#0, %2#1 : tensor, tensor + } + tf_executor.fetch %1#0 : tensor + } + return +} + +// CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: {{%.+}}, [[ISLAND_1_CTRL:%.+]] = tf_executor.island { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_CTRL]] : diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index 8457d9c62cd..6cda668ab0f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -82,6 +82,38 @@ func @same_resource_load_and_store() -> tensor<*xi32> { // ----- +// Tests that a resource ops with both load and store are hoisted +// but input to load and output from store have mixed defined/undefined shapes. + +// CHECK-LABEL: func @same_resource_load_and_store_cast +func @same_resource_load_and_store_cast() -> tensor<1xi32> { + + // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) + // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" + // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) + // CHECK: %[[CAST_RES:[0-9]*]] = "tf.Cast"(%[[COMPUTE_RES]]) + // CHECK: tf_device.return %[[CAST_RES]], %[[COMPUTE_RES]] + // CHECK: {cluster_attr = "cluster_attr"} + // CHECK-SAME: () -> (tensor<1xi32>, tensor<*xi32>) + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) + + %1 = "tf_device.cluster"() ( { + %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<1xi32> + %3 = "tf.SomeComputation"(%2) : (tensor<1xi32>) -> (tensor<*xi32>) + "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () + %4 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<1xi32> + tf_device.return %4 : tensor<1xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<1xi32> + + // CHECK: return %[[CLUSTER_RES]]#0 + return %1 : tensor<1xi32> +} + +// ----- + // Tests that internal resource operations are not hoisted. // CHECK-LABEL: func @internal_resource @@ -1091,3 +1123,129 @@ func @type_refinement_use_refined_type() -> tensor<4xi32> { return %1 : tensor<4xi32> } +// ----- + +!tf_res = type tensor<*x!tf.resource>> + +// Test all tf.VarIsInitializedOp's are set to true. +// CHECK-LABEL: func @tpu_computation +func @tpu_computation(%arg0: !tf_res, %arg1: tensor, %arg2: tensor) { + %0 = "tf_device.cluster"() ( { + %1 = "tf.Case"(%arg2, %arg0) {branches = [@case_branch], is_stateless = false} : (tensor, !tf_res) -> tensor + + // CHECK: "tf.CaseRegion" + %2 = "tf.CaseRegion"(%arg2) ( { + // CHECK-NEXT: [[CASE_REGION_BRANCH:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %3 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: "tf.Yield"([[CASE_REGION_BRANCH]]) + "tf.Yield"(%3) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + + %4 = "tf.If"(%arg1, %arg0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false} : (tensor, !tf_res) -> tensor + + // CHECK: "tf.IfRegion" + %5 = "tf.IfRegion"(%arg1) ( { + // CHECK-NEXT: [[IF_REGION_THEN:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %6 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: "tf.Yield"([[IF_REGION_THEN]]) + "tf.Yield"(%6) : (tensor) -> () + // CHECK-NEXT: }, { + }, { + // CHECK-NEXT: [[IF_REGION_ELSE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %7 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: "tf.Yield"([[IF_REGION_ELSE]]) + "tf.Yield"(%7) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + + %8:2 = "tf.While"(%arg0, %arg1) {body = @while_body, cond = @while_cond, is_stateless = false} : (!tf_res, tensor) -> (!tf_res, tensor) + + // CHECK: "tf.WhileRegion" + %9 = "tf.WhileRegion"(%arg1) ( { + // CHECK-NEXT: ^{{.+}}({{.+}}: tensor): + ^cond(%carg0: tensor): + // CHECK-NEXT: [[WHILE_REGION_COND:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %10 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: "tf.Yield"([[WHILE_REGION_COND]]) + "tf.Yield"(%10) : (tensor) -> () + // CHECK-NEXT: }, { + }, { + // CHECK-NEXT: ^{{.+}}({{.+}}: tensor): + ^body(%barg0: tensor): + // CHECK-NEXT: [[WHILE_REGION_BODY:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %11 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: "tf.Yield"([[WHILE_REGION_BODY]]) + "tf.Yield"(%11) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + + %12 = "tf.StatefulPartitionedCall"(%arg0) {f = @callee, config = "", config_proto = "", executor_type = ""} : (!tf_res) -> tensor + + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %13 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + + // CHECK: tf_device.return [[TRUE]] : + tf_device.return %13 : tensor + }) : () -> tensor + return +} + +// CHECK-LABEL: func @case_branch +func @case_branch(%arg0: !tf_res) -> tensor { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %0 : tensor +} + +// CHECK-LABEL: func @if_then +func @if_then(%arg0: !tf_res) -> tensor { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %0 : tensor +} + +// CHECK-LABEL: func @if_else +func @if_else(%arg0: !tf_res) -> tensor { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %0 : tensor +} + +// CHECK-LABEL: func @while_cond +// CHECK-SAME: ({{.+}}: tensor) +func @while_cond(%arg0: !tf_res, %arg1: tensor) -> tensor { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %0 : tensor +} + +// CHECK-LABEL: func @while_body +// CHECK-SAME: ({{.+}}: tensor) +func @while_body(%arg0: !tf_res, %arg1: tensor) -> (!tf_res, tensor) { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %arg0, %0 : !tf_res, tensor +} + +// CHECK-LABEL: func @callee +func @callee(%arg0: !tf_res) -> tensor { + // CHECK: [[TRUE:%.+]] = "tf.Const" + // CHECK-SAME: value = dense : tensor + %0 = "tf.VarIsInitializedOp"(%arg0) : (!tf_res) -> tensor + // CHECK-NEXT: return [[TRUE]] : + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir index e4fdad2eddb..17329050f3e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir @@ -122,6 +122,108 @@ func @while_cond(%arg0: tensor, %arg1: tensor) -> tensor // ----- +// Tests WhileRegion Op. + +// CHECK-LABEL: func @main() +func @main() -> () { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.Stack + // CHECK: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: tf.AssignVariableOp + // CHECK: tf.AssignVariableOp + %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + // CHECK: tf.WhileRegion + %while = "tf.WhileRegion"(%max_size) ({ + // CHECK: ^bb0(%[[BARG0:.*]]: tensor + ^bb0(%barg0: tensor): + // CHECK: "tf._SomeOp"(%[[BARG0]]) + %pred = "tf._SomeOp"(%barg0) : (tensor) -> tensor + "tf.Yield"(%pred) : (tensor) -> () + }, { + // CHECK: ^bb0(%[[BARG0:.*]]: tensor + ^bb0(%barg0: tensor): + // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG0]], %[[CONST1]]) + %sub = "tf.Sub"(%barg0, %const1) : (tensor, tensor) -> tensor + %elem = "tf._SomeOp"() : () -> tensor + // CHECK-NOT: "tf.StackPushV2" + // CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) + // CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BUFFER_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) + // CHECK: "tf.AssignVariableOp"(%[[SIZE]] + // CHECK-NOT: "tf.StackPushV2" + %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + // CHECK: "tf.Yield"(%[[SUB]]) + "tf.Yield"(%sub) : (tensor) -> () + }) {is_stateless = false} + : (tensor) -> tensor + // CHECK-NOT: tf.StackPopV2 + // CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) + // CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) + // CHECK: %[[POP_VAL:.*]] = "tf.Slice"(%[[BUFFER_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[SIZE]] + %pop = "tf.StackPopV2"(%stack) : (tensor) -> tensor + // CHECK-NOT: tf.StackCloseV2 + "tf.StackCloseV2"(%stack) : (tensor) -> () + return +} + +// ----- + +// Test CaseRegionOp + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor +func @main(%arg0: tensor) -> () { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.StackV2 + // CHECK: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: tf.AssignVariableOp + // CHECK: tf.AssignVariableOp + %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + // CHECK: %[[CASE_OUTPUT:.*]] = "tf.CaseRegion"(%[[BRANCH_INDEX]]) ( { + %case_op = "tf.CaseRegion"(%arg0) ({ + %elem = "tf._SomeOp"() : () -> tensor + // CHECK-NOT: tf.StackPushV2 + // CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) + // CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BUFFER_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) + // CHECK: "tf.AssignVariableOp"(%[[SIZE]] + %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + "tf.Yield"(%elem) : (tensor) -> () + }, { + %elem = "tf._SomeOtherOp"() : () -> tensor + // CHECK-NOT: tf.StackPushV2 + // CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) + // CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BUFFER_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) + // CHECK: "tf.AssignVariableOp"(%[[SIZE]] + %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + "tf.Yield"(%elem) : (tensor) -> () + }, { + // CHECK-NOT: tf.StackPopV2 + // CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) + // CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) + // CHECK: %[[POP_VAL:.*]] = "tf.Slice"(%[[BUFFER_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[SIZE]] + %pop = "tf.StackPopV2"(%stack) : (tensor) -> tensor + "tf.Yield"(%pop) : (tensor) -> () + }) {is_stateless = false} + : (tensor) -> tensor + // CHECK-NOT: tf.StackPopV2 + %pop = "tf.StackPopV2"(%stack) : (tensor) -> tensor + // CHECK-NOT: tf.StackCloseV2 + "tf.StackCloseV2"(%stack) : (tensor) -> () + return +} + +// ----- // Tests IfOp. // CHECK-LABEL: func @main @@ -308,3 +410,53 @@ func @if_else(%arg0: tensor, %arg1: tensor) -> tenso %push = "tf.StackPushV2"(%arg1, %elem) {swap_memory = false} : (tensor, tensor) -> tensor return %arg1 : tensor } + +// ----- + +// Tests that the pass returns meaningful error message when WhileRegion op has +// resource arguments. +func @main() -> () { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + %elem = "tf._SomeOp"() : () -> tensor + %push_0 = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + // expected-error @+1 {{found unexpected type 'tensor>>' of operand #0, resource type operands are expected to have been canonicalized away for region based control flow ops}} + %1:2 = "tf.WhileRegion"(%stack, %max_size) ({ + ^bb0 (%carg0: tensor, %carg1: tensor): + %pred = "tf._SomeOp"(%carg1) : (tensor) -> tensor + "tf.Yield"(%pred) : (tensor) -> () + }, { + ^bb0 (%carg0: tensor, %carg1: tensor): + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %sub = "tf.Sub"(%carg1, %const1) : (tensor, tensor) -> tensor + %push_1 = "tf.StackPushV2"(%carg0, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + "tf.Yield"(%carg0, %sub) : (tensor, tensor) -> () + }) {is_stateless = false} + : (tensor, tensor) -> (tensor, tensor) + %pop = "tf.StackPopV2"(%1#0) : (tensor) -> tensor + "tf.StackCloseV2"(%stack) : (tensor) -> () + return +} + +// ----- + +// Tests that the pass returns meaningful error message when IfRegion op has +// resource returns. + +func @main(%arg0: tensor) -> () { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + // expected-error @+1 {{found unexpected type 'tensor' of result #0, resource type results are expected to have been canonicalized away for region based control flow ops}} + %if_op = "tf.IfRegion"(%arg0) ({ + %elem = "tf._SomeOp"() : () -> tensor + %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + "tf.Yield"(%stack) : (tensor) -> () + }, { + %pop = "tf.StackPopV2"(%stack) : (tensor) -> tensor + "tf.Yield"(%stack) : (tensor) -> () + }) {is_stateless = false} + : (tensor) -> tensor + %pop = "tf.StackPopV2"(%if_op) : (tensor) -> tensor + "tf.StackCloseV2"(%stack) : (tensor) -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir index 0c4dc77cf69..8200cedaea9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir @@ -162,7 +162,7 @@ func @main() -> () { // ----- -// Test tensor list grads. +// Test tensor array grads. // CHECK-LABEL: func @main func @main() { @@ -314,6 +314,93 @@ func @else_branch(%arg0: tensor) -> tensor { // ----- +// Tests WhileRegion loop with access to the tensor array defined outside and +// its gradient defined inside. + +// CHECK-LABEL: func @main +func @main() -> () { + // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor} + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK-NOT: tf.TensorArrayV3 + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + // CHECK: %[[FLOW_INIT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} + // CHECK: %[[WHILE:.*]]:2 = "tf.WhileRegion"(%[[FLOW_INIT]], %[[SIZE]]) ( { + %while:2 = "tf.WhileRegion"(%ta#1, %size) ({ + // CHECK: ^bb0(%[[BARG0:.*]]: tensor, %[[BARG1:.*]]: tensor): + ^bb0(%barg0: tensor, %barg1: tensor): + // CHECK: %[[PRED:.*]] = "tf._SomeOp"(%[[BARG1]]) + // CHECK: "tf.Yield"(%[[PRED]]) + %pred = "tf._SomeOp"(%barg1) : (tensor) -> tensor + "tf.Yield" (%pred) : (tensor) -> () + }, { + // CHECK: ^bb0(%[[BARG0:.*]]: tensor, %[[BARG1:.*]]: tensor): + ^bb0(%barg0: tensor, %barg1: tensor): + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %sub = "tf.Sub"(%barg1, %const1) : (tensor, tensor) -> tensor + %elem = "tf._SomeOp"() : () -> tensor<3xf32> + %flow = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + // CHECK: %[[READ_VAR:.*]] = "tf.ReadVariableOp"(%[[VAR]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ_VAR]], + // CHECK: "tf.AssignVariableOp"(%[[VAR]], %[[UPDATE]]) + %write = "tf.TensorArrayWriteV3"(%ta#0, %sub, %elem, %flow) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + // CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %grad:2 = "tf.TensorArrayGradV3"(%ta#0, %write) {source = "a"} : (tensor, tensor) -> (tensor, tensor) + // CHECK: %[[READ_GVAR:.*]] = "tf.ReadVariableOp"(%[[GVAR]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ_GVAR]], + // CHECK: "tf.AssignVariableOp"(%[[GVAR]], %[[UPDATE]]) + %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %sub, %elem, %grad#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + "tf.Yield"(%gwrite, %sub) : (tensor, tensor) -> () + }) {is_stateless = false} + : (tensor, tensor) -> (tensor, tensor) + // CHECK: %[[READ_VAR:.*]] = "tf.ReadVariableOp"(%[[VAR]]) + // CHECK: "tf.Slice"(%[[READ_VAR]] + %read = "tf.TensorArrayReadV3"(%ta#0, %index, %while#0) : (tensor, tensor, tensor) -> tensor<3xf32> + return +} + +// ----- + +// Test IfRegion op. + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[PRED:.*]]: tensor +func @main(%arg0: tensor) -> () { + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.TensorArrayV3 + // CHECK: %[[TA_BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: "tf.AssignVariableOp"(%[[TA_BUFFER]] + // CHECK-NOT: tf.TensorArrayV3 + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + // CHECK: "tf.IfRegion"(%[[PRED]]) ( { + %case_op = "tf.IfRegion"(%arg0) ({ + // CHECK: %[[TA_VAL:.*]] = "tf.ReadVariableOp"(%[[TA_BUFFER]]) + // CHECK: "tf.Slice"(%[[TA_VAL]] + // CHECK-NOT: tf.TensorArrayReadV3 + %idx = "tf.Const"() {value = dense<3> : tensor} : () -> tensor + %read = "tf.TensorArrayReadV3"(%ta#0, %idx, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + "tf.Yield"(%ta#1) : (tensor) -> () + // CHECK: }, { + }, { + // CHECK: %[[TA_VAL:.*]] = "tf.ReadVariableOp"(%[[TA_BUFFER]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[TA_VAL]] + // CHECK: "tf.AssignVariableOp"(%[[TA_BUFFER]], %[[UPDATE]]) + // CHECK-NOT: tf.TensorArrayWriteV3 + %idx = "tf.Const"() {value = dense<4> : tensor} : () -> tensor + %elem = "tf._SomeOp"() : () -> tensor<3xf32> + %write = "tf.TensorArrayWriteV3"(%ta#0, %idx, %elem, %ta#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + "tf.Yield"(%write) : (tensor) -> () + // CHECK: }) {is_stateless = false} : (tensor) -> tensor + }) {is_stateless = false} : (tensor) -> tensor + %idx = "tf.Const"() {value = dense<6> : tensor} : () -> tensor + // CHECK-NOT: tf.TensorArrayReadV3 + %read_val = "tf.TensorArrayReadV3"(%ta#0, %idx, %case_op) : (tensor, tensor, tensor) -> tensor<3xf32> + return +} + +// ----- + // Tests (Stateful)PartitionedCall op with access to the tensor array defined // outside and its gradient defined inside. The gradient creation should be // moved outside. @@ -458,6 +545,59 @@ func @callee() -> (tensor<*xf32>) attributes {sym_visibility = "private"} { // CHECK: return %[[CAST]] : tensor<*xf32> return %val : tensor<*xf32> } +// ----- + +// Test CaseRegion with gradient inside PartitionedCall Op. The gradient local +// variable should be inserted before the PartitionedCall op. + +// CHECK-LABEL: func @main() +func @main() -> () { + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK-NOT: tf.TensorArrayV3 + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %cond = "tf._SomeOp"() : () -> tensor + // CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: "tf.StatefulPartitionedCall"(%[[VAR]], %[[GVAR]]) + %call = "tf.StatefulPartitionedCall"(%ta#0) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor) -> tensor + // CHECK-NOT: tf.TensorArrayReadV3 + %read = "tf.TensorArrayReadV3"(%call, %index, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + return +} + +// CHECK-LABEL: func @callee +// CHECK-SAME: %[[VAR:.*]]: tensor>>, %[[GVAR:.*]]: tensor>> +func @callee(%arg0: tensor) -> tensor attributes {sym_visibility = "private"} { + %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem = "tf._SomeOp"() : () -> tensor<3xf32> + %flow = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + // CHECK: %[[BR_INDEX:.*]] = "tf.SomeOp"() : () -> tensor + %branch_index = "tf.SomeOp"() : () -> tensor + // CHECK: "tf.CaseRegion"(%[[BR_INDEX]]) ( { + "tf.CaseRegion"(%branch_index) ({ + // CHECK: %[[READ_GVAR:.*]] = "tf.ReadVariableOp"(%[[GVAR]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ_GVAR]], + // CHECK: "tf.AssignVariableOp"(%[[GVAR]], %[[UPDATE]]) + %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor, tensor) -> (tensor, tensor) + %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %index, %elem, %grad#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + "tf.Yield"() : () -> () + }, { + // CHECK: %[[READ_VAR:.*]] = "tf.ReadVariableOp"(%[[VAR]]) + // CHECK: "tf.Slice"(%[[READ_VAR]] + %read = "tf.TensorArrayReadV3"(%arg0, %index, %flow) : (tensor, tensor, tensor) -> tensor<3xf32> + "tf.Yield"() : () -> () + }, { + // CHECK: %[[READ_VAR:.*]] = "tf.ReadVariableOp"(%[[VAR]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ_VAR]], + // CHECK: "tf.AssignVariableOp"(%[[VAR]], %[[UPDATE]]) + %write = "tf.TensorArrayWriteV3"(%arg0, %index, %elem, %flow) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + "tf.Yield"() : () -> () + }) {is_stateless = false} : (tensor) -> () + // CHECK: return %[[VAR]] + return %arg0 : tensor +} // ----- @@ -501,3 +641,46 @@ func @if_then(%arg0: tensor, %arg1: tensor) -> tenso func @if_else(%arg0: tensor, %arg1: tensor) -> tensor { return %arg1 : tensor } + +// ----- + +// Tests that the pass returns meaningful error message when region based +// control flow op has resource arguments. +func @main() -> () { + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + // expected-error @+1 {{found unexpected type 'tensor>>' of operand #0, resource type operands are expected to have been canonicalized away for region based control flow ops}} + %1:2 = "tf.WhileRegion"(%ta#0, %size) ({ + ^bb0 (%carg0: tensor, %carg1: tensor): + %pred = "tf._SomeOp"(%carg1) : (tensor) -> tensor + "tf.Yield"(%pred) : (tensor) -> () + }, { + ^bb0 (%carg0: tensor, %carg1: tensor): + %idx = "tf.Const"() {value = dense<3> : tensor} : () -> tensor + %read_true = "tf.TensorArrayReadV3"(%carg0, %idx, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + "tf.Yield"(%carg0, %idx) : (tensor, tensor) -> () + }) {is_stateless = false} + : (tensor, tensor) -> (tensor, tensor) + return +} + +// ----- + +// Tests that the pass returns meaningful error message when region based +// control flow op has resource returns. + +func @main(%arg0: tensor) -> (tensor<3xf32>) { + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + // expected-error @+1 {{found unexpected type 'tensor' of result #1, resource type results are expected to have been canonicalized away for region based control flow ops}} + %if_op:2 = "tf.IfRegion"(%arg0) ({ + %idx = "tf.Const"() {value = dense<3> : tensor} : () -> tensor + %read_true = "tf.TensorArrayReadV3"(%ta#0, %idx, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + "tf.Yield"(%read_true, %ta#0) : (tensor<3xf32>, tensor) -> () + }, { + %idx = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %read_false = "tf.TensorArrayReadV3"(%ta#0, %idx, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + "tf.Yield"(%read_false, %ta#0) : (tensor<3xf32>, tensor) -> () + }) {is_stateless = false} : (tensor) -> (tensor<3xf32>, tensor) + return %if_op : tensor<3xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 0f8137af672..8b97bfdad6d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -224,17 +224,17 @@ func @testIncompatibleElementTypes(%arg0: tensor<3x2xf32>, %arg1: tensor<3x2xf32 // ----- // CHECK-LABEL: func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) -func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) -> (tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32>) { +func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) -> (tensor<100x100xf32>, tensor<*xf32>, tensor<100x100xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32>) { %shape1 = constant dense<100> : tensor<2xi32> - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) - %shape2 = "tf.Shape"(%arg0) {device = "", name = "Shape", T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> (tensor) - %r2 = "tf.Reshape"(%arg1, %shape2) {device = "", name = "Reshape_1", T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<*xf32>, tensor) -> (tensor<*xf32>) - %r3 = "tf.Reshape"(%arg2, %shape1) {device = "", name = "Reshape_1", T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<10000xf32>, tensor<2xi32>) -> (tensor<10000xf32>) + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<100x100xf32> + %shape2 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor + %r2 = "tf.Reshape"(%arg1, %shape2) : (tensor<*xf32>, tensor) -> tensor<*xf32> + %r3 = "tf.Reshape"(%arg2, %shape1) : (tensor<10000xf32>, tensor<2xi32>) -> tensor<100x100xf32> %shape3 = constant dense<[-1, 100]> : tensor<2xi32> - %r4 = "tf.Reshape"(%arg2, %shape3) {device = "", name = "Reshape_1", T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<10000xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) - %r5 = "tf.Reshape"(%arg0, %arg3) {T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<*xf32>, tensor<*xi32>) -> (tensor<*xf32>) - %r6 = "tf.Reshape"(%arg2, %arg3) {T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<10000xf32>, tensor<*xi32>) -> (tensor<*xf32>) - return %r1, %r2, %r3, %r4, %r5, %r6: tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32> + %r4 = "tf.Reshape"(%arg2, %shape3) : (tensor<10000xf32>, tensor<2xi32>) -> tensor<100x100xf32> + %r5 = "tf.Reshape"(%arg0, %arg3) : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> + %r6 = "tf.Reshape"(%arg2, %arg3) : (tensor<10000xf32>, tensor<*xi32>) -> tensor<*xf32> + return %r1, %r2, %r3, %r4, %r5, %r6: tensor<100x100xf32>, tensor<*xf32>, tensor<100x100xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32> } // ----- @@ -243,25 +243,41 @@ func @testReshape(tensor<*xf32>, tensor<*xf32>) -> (tensor<100x100xf32>) { ^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>): %shape1 = constant dense<100.> : tensor<2xf32> // expected-error @+1 {{must be tensor of 32/64-bit signed integer values}} - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xf32>) -> (tensor<100x100xf32>) + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xf32>) -> tensor<100x100xf32> return %r1 : tensor<100x100xf32> } // ----- // tf.Reshape with incorrect element number. -func @testReshape(%arg0: tensor<10x10x10xf32>) -> tensor<100x100xf32> { - %shape1 = constant dense<100> : tensor<2xi32> - // expected-error @+1 {{number of output elements (10000) does not match expected number of elements (1000)}} - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) +func @testReshape(%arg0: tensor<10x10x10xf32>, %shape1: tensor<2xi32>) -> tensor<100x100xf32> { + // expected-error @+1 {{requires 'output' number of elements to match 'tensor' number of elements, but got 10000 and 1000}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10xf32>, tensor<2xi32>) -> tensor<100x100xf32> return %r1 : tensor<100x100xf32> } +// ----- +// tf.Reshape with incorrect shape operand rank. +func @testReshape(%arg0: tensor<10x10x10xf32>, %shape1: tensor<2x2xi32>) -> tensor<*xf32> { + // expected-error @+1 {{requires 'shape' to be rank 1, but got 2}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10xf32>, tensor<2x2xi32>) -> tensor<*xf32> + return %r1 : tensor<*xf32> +} + // ----- // tf.Reshape with more than one -1 in the shape. func @testReshape(%arg0: tensor<10x10x10x10xf32>) -> tensor<100x100xf32> { %shape1 = constant dense<-1> : tensor<2xi32> - // expected-error @+1 {{more than one component of shape are -1}} - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) + // expected-error @+1 {{requires 'shape' to have at most one dynamic dimension, but got multiple dynamic dimensions at indices 0 and 1}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> tensor<100x100xf32> + return %r1 : tensor<100x100xf32> +} + +// ----- +// tf.Reshape with shape operand element < -1. +func @testReshape(%arg0: tensor<10x10x10x10xf32>) -> tensor<100x100xf32> { + %shape1 = constant dense<[100, -2]> : tensor<2xi32> + // expected-error @+1 {{requires 'shape' to have dimensions greater than -1, but got -2 at index 1}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> tensor<100x100xf32> return %r1 : tensor<100x100xf32> } @@ -269,19 +285,68 @@ func @testReshape(%arg0: tensor<10x10x10x10xf32>) -> tensor<100x100xf32> { // tf.Reshape with -1 in the shape can't infer the dimension. func @testReshape(%arg0: tensor<10x10x10x10xf32>) -> tensor<100x100xf32> { %shape1 = constant dense<[101, -1]> : tensor<2xi32> - // expected-error @+1 {{one component of shape is -1 but couldn't infer the dimension}} - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) + // expected-error @+1 {{requires 'tensor' number of elements be a multiple of 101, but got 10000}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> tensor<100x100xf32> return %r1 : tensor<100x100xf32> } // ----- -// tf.Reshape with a first operand that has non-static shape. +// tf.Reshape with incorrect output rank. +func @testReshape(%arg0: tensor<10x10xf32>) -> tensor { + %shape1 = constant dense<[100]> : tensor<1xi32> + // expected-error @+1 {{requires 'output' type 'tensor' to be cast compatible with expected type 'tensor<100xf32>'}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10xf32>, tensor<1xi32>) -> tensor + return %r1 : tensor +} + +// ----- +// tf.Reshape with incorrect output dimension. +func @testReshape(%arg0: tensor<1000xf32>) -> tensor { + %shape1 = constant dense<[10, 10, 10]> : tensor<3xi32> + // expected-error @+1 {{requires 'output' type 'tensor' to be cast compatible with expected type 'tensor<10x10x10xf32>'}} + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<1000xf32>, tensor<3xi32>) -> tensor + return %r1 : tensor +} + +// ----- +// tf.Reshape with a shape operand that has 0 for one of its elements. +func @testReshape(%arg0: tensor<10x10x10xf32>) -> tensor { + %shape1 = constant dense<[-1, 0]> : tensor<2xi32> + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10xf32>, tensor<2xi32>) -> tensor + return %r1 : tensor +} + +// ----- +// tf.Reshape with a tensor operand that has 0 for one of its elements. +func @testReshape(%arg0: tensor<10x10x0xf32>) -> tensor { + %shape1 = constant dense<[-1, 0]> : tensor<2xi32> + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x0xf32>, tensor<2xi32>) -> tensor + return %r1 : tensor +} + +// ----- +// tf.Reshape with a tensor operand that has non-static shape. func @testReshape(%arg0: tensor<10x10x?xf32>) -> tensor<10x10xf32> { %shape1 = constant dense<[10, 10]> : tensor<2xi32> - %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x?xf32>, tensor<2xi32>) -> (tensor<10x10xf32>) + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x?xf32>, tensor<2xi32>) -> tensor<10x10xf32> return %r1 : tensor<10x10xf32> } +// ----- +// tf.Reshape with tensor operand that has non-static shape and shape operand +// with static shape. +func @testReshape(%arg0: tensor<10x10x?xf32>, %shape1: tensor<2xi32>) -> tensor<100x100xf32> { + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x?xf32>, tensor<2xi32>) -> tensor<100x100xf32> + return %r1 : tensor<100x100xf32> +} + +// ----- +// tf.Reshape with tensor and shape operands with static shape. +func @testReshape(%arg0: tensor<10x10x10x10xf32>, %shape1: tensor<2xi32>) -> tensor<100x100xf32> { + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x10x10xf32>, tensor<2xi32>) -> tensor<100x100xf32> + return %r1 : tensor<100x100xf32> +} + // ----- // CHECK-LABEL: func @testValidAvgPool diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index 281e4baaa12..3c2344be1e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -480,6 +480,53 @@ func @cluster_nested_op_using_resource() { // CHECK: "tf.opA"() ( { // CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) + +// ----- + + +!tf_res = type tensor<*x!tf.resource>> + +// Test multiple replicated clusters interleaved and uses resource variables. +// CHECK-LABEL: func @multiple_replicated_interleaved +func @multiple_replicated_interleaved(%arg0: !tf_res) { + "tf.TPUReplicateMetadata"() {_tpu_replicate = "a", num_replicas = 2, topology = "topology"} : () -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "b", num_replicas = 2, topology = "topology"} : () -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "c", num_replicas = 2, topology = "topology"} : () -> () + %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (!tf_res, !tf_res) -> !tf_res + %1 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (!tf_res, !tf_res) -> !tf_res + %2 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (!tf_res, !tf_res) -> !tf_res + %3 = "tf.ReadVariableOp"(%0) {_tpu_replicate = "a"} : (!tf_res) -> tensor + %4 = "tf.ReadVariableOp"(%1) {_tpu_replicate = "b"} : (!tf_res) -> tensor + %5 = "tf.ReadVariableOp"(%2) {_tpu_replicate = "c"} : (!tf_res) -> tensor + %6 = "tf.Identity"(%3) {_tpu_replicate = "a"} : (tensor) -> tensor + %7 = "tf.Identity"(%4) {_tpu_replicate = "b"} : (tensor) -> tensor + %8 = "tf.Identity"(%5) {_tpu_replicate = "c"} : (tensor) -> tensor + %9:2 = "tf.TPUReplicatedOutput"(%6) : (tensor) -> (tensor, tensor) + %10:2 = "tf.TPUReplicatedOutput"(%7) : (tensor) -> (tensor, tensor) + %11:2 = "tf.TPUReplicatedOutput"(%8) : (tensor) -> (tensor, tensor) + return +} + +// CHECK: tf_device.replicate +// CHECK: tf_device.replicate +// CHECK: tf_device.replicate + + +// ----- + + +// Test cluster that is replicated but has a non TPUReplicatedOutput consumer. +// CHECK-LABEL: func @replicated_non_replicated_output +func @replicated_non_replicated_output() { + %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor + %1 = "tf.opB"(%0) : (tensor) -> tensor + "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> () + return +} + +// CHECK: [[REPLICATE:%.+]]:2 = tf_device.replicate +// CHECK: "tf.opB"([[REPLICATE]]#0) + // ----- @@ -535,20 +582,6 @@ func @mismatched_replicated_output() { // ----- -// Test cluster that should be replicated where its outputs do not lead to a -// TPUReplicatedOutput. -func @missing_replicated_output() { - // expected-error@+1 {{requires output of tf_device.cluster to lead to a 'tf.TPUReplicatedOutput' op}} - %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor - %1 = "tf.opB"(%0) : (tensor) -> tensor - "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> () - return -} - - -// ----- - - // Test unused TPUReplicatedInput that has more than one operand. func @leftover_replicated_input(%arg0: tensor) { %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (tensor, tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index 2e3e38c7004..a3d5a43a214 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -227,3 +227,28 @@ func @pcall_func_body(%arg0: tensor<*xi1>) -> tensor { %2 = "tf.D"(%1) : (tensor<*xi1>) -> (tensor) return %2 : tensor } + +// ----- + +// Tests that output sharding inside a functional op is parsed correctly. + +// CHECK-LABEL: func @check_sharding_inside_functional_op +func @check_sharding_inside_functional_op(%arg0: tensor<*xi32>) { + "tf_device.cluster_func"(%arg0) {func = @cluster_func, step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32> + // CHECK: input_sharding_configuration + // CHECK-SAME: ["\01\02\03"] + // CHECK: output_sharding_configuration + // CHECK-SAME: ["\01\02\03"] + return +} + +func @cluster_func(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.PartitionedCall"(%arg0) {f= @func_body, config="", config_proto="", executor_type=""} : (tensor<*xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + +func @func_body(%arg0: tensor<*xi32>)-> tensor<*xi32> { + %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32> + %1 = "tf.Identity"(%0) : (tensor<*xi32>) -> (tensor<*xi32>) + return %1 : tensor<*xi32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index d5b7eb7a739..945573aa978 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -209,10 +209,8 @@ def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)), def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape), (TF_ReshapeOp $arg, $shape)>; -def IsSame : Constraint>; -def ReshapeToSelfShape : Pat<(TF_ReshapeOp $arg0, (TF_ShapeOp $arg1)), - (replaceWithValue $arg0), - [(IsSame $arg0, $arg1)]>; +def ReshapeToSelfShape : Pat<(TF_ReshapeOp $x, (TF_ShapeOp $x)), + (replaceWithValue $x)>; //===----------------------------------------------------------------------===// // Select op patterns. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc new file mode 100644 index 00000000000..66311101cee --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc @@ -0,0 +1,119 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace { + +class ConvertResultsBroadcastableShapeOp : public RewritePattern { + public: + ConvertResultsBroadcastableShapeOp() + : RewritePattern(1, MatchAnyOpTypeTag()) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override; +}; + +class BroadcastFoldPass : public PassWrapper { + public: + void runOnFunction() override; +}; + +LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + if (!op->hasTrait()) return failure(); + if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1) + return failure(); + + // Check that the result shape is fully defined. + auto result_type = + op->getResultTypes().front().dyn_cast_or_null(); + if (!result_type || !result_type.hasStaticShape()) return failure(); + + for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) { + // Check that the i'th operand is a broadcast. + auto broadcast = llvm::dyn_cast_or_null( + op->getOpOperand(i).get().getDefiningOp()); + if (!broadcast) continue; + + // Check that the operand of the broadcast has fully defined shape. + auto broadcast_arg_type = + broadcast.input().getType().dyn_cast_or_null(); + if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; + + // Check that the other argument has fully defined shape. + auto argument_type = op->getOpOperand(1 - i) + .get() + .getType() + .dyn_cast_or_null(); + if (!argument_type || !argument_type.hasStaticShape()) continue; + + // Check that the input of the broadcast and the other operand is broadcast + // compatible. + llvm::SmallVector broadcasted_shape; + if (!OpTrait::util::getBroadcastedShape(broadcast_arg_type.getShape(), + argument_type.getShape(), + broadcasted_shape)) + continue; + + // Check that an implicit broadcast between the operand of the broadcast and + // the other argument would result in the same type as the result type. + if (broadcasted_shape != result_type.getShape()) continue; + + // Update the operand of the op to be the operand of the broadcast. + rewriter.updateRootInPlace( + op, [&]() { op->getOpOperand(i).set(broadcast.input()); }); + return success(); + } + + return failure(); +} + +void BroadcastFoldPass::runOnFunction() { + OwningRewritePatternList patterns; + auto func = getFunction(); + + patterns.insert(); + applyPatternsAndFoldGreedily(func, patterns); +} + +} // namespace + +namespace TF { +std::unique_ptr> CreateBroadcastFoldPass() { + return absl::make_unique(); +} +} // namespace TF + +static PassRegistration pass( + "tf-broadcast-fold", + "Fold explicit broadcasts into the following operations if they support " + "implicit broadcasting on their operand."); + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc index 859d3ffb23c..26c0126932c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -30,6 +31,18 @@ limitations under the License. namespace mlir { namespace tf_executor { +namespace { + +// Checks if a tf_executor.Graph can be pruned. +// For TensorFlow V1.0 compatibility: when importing a graph without providing +// feeds/fetches/targets we should not attempt to prune. The best approximation +// here is to check if the graph is of the "main" function and does not have the +// "tf.entry_function" attribute defined. +bool CanPruneGraph(FuncOp func) { + return func.getName() != "main" || + func.getAttrOfType("tf.entry_function") != nullptr; +} + // Visits an op's operand if it is an output of an Operation in the same // tf_executor.graph. void VisitOpOperand(GraphOp graph, Value operand, @@ -75,6 +88,8 @@ void VisitOp(GraphOp graph, Operation* op, } } +} // namespace + // Prunes unreachable operations of a tf_executor.graph operation. void PruneGraph(GraphOp graph) { // A graph has a single block which forms a DAG: operations that aren't @@ -107,15 +122,8 @@ namespace { // This transformation pass prunes a TF graph eliminating dead-nodes. struct GraphPruning : public PassWrapper { void runOnFunction() override { - getFunction().walk([](tf_executor::GraphOp graph) { - // For TensorFlow V1.0 compatibility: when importing a graph without - // providing feeds/fetches we should not attempt to prune. The best - // approximation here is to check if the graph does not have any fetched - // values. - if (!graph.GetFetch().getNumOperands()) return; - - PruneGraph(graph); - }); + if (!CanPruneGraph(getFunction())) return; + getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); }); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 002117f15d8..c93679ab7da 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -215,7 +215,7 @@ class LowerAddNOp : public RewritePattern { }; // Lowers DynamicStitch op with constant indices and with static input and -// output shapes using Reshape, UnPack and ConcatV2 op. +// output shapes using Reshape, UnPack and Pack op. // // %indices0 = "tf.Const"() {value = dense<4> : tensor} // %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> : @@ -237,7 +237,7 @@ class LowerAddNOp : public RewritePattern { // : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, // tensor<2xf32>) // %axis = "tf.Const"() {value = dense<0> : tensor} -// %0 = "tf.ConcatV2"(items1#3, items1#2, items1#1, items1#0, %items0, %axis) +// %0 = "tf.Pack"(items1#3, items1#2, items1#1, items1#0, %items0, %axis) // : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, // tensor<2xf32>, tensor) -> tensor<5x2xf32> // @@ -303,8 +303,7 @@ class LowerDynamicStitchOp : public OpRewritePattern { } } - auto axis = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - rewriter.replaceOpWithNewOp(op, op.getType(), values, axis); + rewriter.replaceOpWithNewOp(op, op.getType(), values); return success(); } }; @@ -477,6 +476,210 @@ class LowerInvertPermutationOp } }; +// Approximates lgamma using Lanczos' approximation from +// "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical +// Analysis series B. Vol. 1: +// lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z) +// t(z) = z + kLanczosGamma + 1/2 +// A(z) = kBaseLanczosCoeff +// + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) +// +// Coefficients for the Lanczos approximation of the gamma function. The +// coefficients are uniquely determined by the choice of g and n +// (kLanczosGamma and kLanczosCoefficients.size() + 1). The coefficients below +// correspond to [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were +// evaluated and [7, 9] seemed to be the least sensitive to the quality of the +// log function. In particular, [5, 7] is the only choice where -1.5e-5 <= +// lgamma(2) <= 1.5e-5 for a particularly inaccurate log function. +static constexpr double kLanczosGamma = 7; // aka g +static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478; +static constexpr std::array kLanczosCoefficients = { + 676.520368121885098567009190444019, -1259.13921672240287047156078755283, + 771.3234287776530788486528258894, -176.61502916214059906584551354, + 12.507343278686904814458936853, -0.13857109526572011689554707, + 9.984369578019570859563e-6, 1.50563273514931155834e-7}; + +class LowerLgammaOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::LgammaOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.x(); + TensorType original_tensor_type = op.x().getType().cast(); + + // The approximation is not precise enough for float16. Do the computation + // in float32 for that case. + TensorType tensor_type = original_tensor_type; + FloatType float_type = tensor_type.getElementType().cast(); + bool needs_cast = float_type.getWidth() < 32; + if (needs_cast) { + MLIRContext *context = rewriter.getContext(); + float_type = FloatType::getF32(context); + if (original_tensor_type.hasRank()) { + tensor_type = + RankedTensorType::get(original_tensor_type.getShape(), float_type); + } else { + tensor_type = UnrankedTensorType::get(float_type); + } + input = rewriter.create(loc, tensor_type, input); + } + + // Helper lambda function for creating a ConstOp for a tensor filled with + // the given constant float value. + auto create_const_op = [&rewriter, loc, tensor_type, + float_type](double value) { + return rewriter.create( + loc, DenseElementsAttr::get(tensor_type, + FloatAttr::get(float_type, value))); + }; + + Value one_half = create_const_op(0.5); + Value one = create_const_op(1.0); + Value infinity = create_const_op(std::numeric_limits::infinity()); + Value pi = create_const_op(M_PI); + Value log_pi = create_const_op(std::log(M_PI)); + Value log_sqrt_two_pi = create_const_op((std::log(2) + std::log(M_PI)) / 2); + Value lanczos_gamma_plus_one_half = create_const_op(kLanczosGamma + 0.5); + Value log_lanczos_gamma_plus_one_half = + create_const_op(std::log(kLanczosGamma + 0.5)); + Value base_lanczos_coeff = create_const_op(kBaseLanczosCoeff); + + Value minus_input = rewriter.create(loc, input); + Value input_minus_one = rewriter.create(loc, input, one); + + // If the input is less than 0.5 use Euler's reflection formula: + // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) + Value need_to_reflect = rewriter.create(loc, input, one_half); + Type tensor_bool_type = need_to_reflect.getType(); + Value z = rewriter.create(loc, need_to_reflect, minus_input, + input_minus_one); + + Value x = base_lanczos_coeff; + for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { + Value lanczos_coefficient = create_const_op(kLanczosCoefficients[i]); + Value index = create_const_op(static_cast(i)); + Value z_plus_index = rewriter.create(loc, z, index); + Value z_plus_index_plus_one = + rewriter.create(loc, z_plus_index, one); + Value incr = rewriter.create(loc, lanczos_coefficient, + z_plus_index_plus_one); + x = rewriter.create(loc, x, incr); + } + + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + Value t = rewriter.create(loc, lanczos_gamma_plus_one_half, z); + Value z_div_lanczos_gamma_plus_one_half = + rewriter.create(loc, z, lanczos_gamma_plus_one_half); + Value log1p_z_div_lanczos_gamma_plus_one_half = + rewriter.create(loc, z_div_lanczos_gamma_plus_one_half); + Value log_t = + rewriter.create(loc, log_lanczos_gamma_plus_one_half, + log1p_z_div_lanczos_gamma_plus_one_half); + + // Compute the final result (modulo reflection). t(z) may be large, and we + // need to be careful not to overflow to infinity in the first term of + // + // (z + 1/2) * log(t(z)) - t(z). + // + // Therefore we compute this as + // + // (z + 1/2 - t(z) / log(t(z))) * log(t(z)). + // + // log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x); + Value t_div_log_t = rewriter.create(loc, t, log_t); + Value one_half_minus_t_div_log_t = + rewriter.create(loc, one_half, t_div_log_t); + Value z_plus_one_half_minus_t_div_log_t = + rewriter.create(loc, z, one_half_minus_t_div_log_t); + Value z_plus_one_half_minus_t_div_log_t_mul_log_t = + rewriter.create(loc, z_plus_one_half_minus_t_div_log_t, + log_t); + Value log_x = rewriter.create(loc, x); + Value log_y_rhs = rewriter.create( + loc, z_plus_one_half_minus_t_div_log_t_mul_log_t, log_x); + Value log_y = rewriter.create(loc, log_sqrt_two_pi, log_y_rhs); + + // Compute the reflected value, used when x < 0.5: + // + // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). + // + // (The abs is because lgamma is the log of the absolute value of the gamma + // function.) + // + // We have to be careful when computing the final term above. gamma(x) goes + // to +/-inf at every integer x < 0, and this is controlled by the + // sin(pi * x) term. The slope is large, so precision is particularly + // important. + // + // Because abs(sin(pi * x)) has period 1, we can equivalently use + // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This + // is more numerically accurate: It doesn't overflow to inf like pi * x can, + // and if x is an integer, it evaluates to 0 exactly, which is significant + // because we then take the log of this value, and log(0) is inf. + // + // We don't have a frac(x) primitive in XLA and computing it is tricky, but + // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for + // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). + // + // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close + // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain + // [0, 1] is symmetric across the line Y=0.5. + Value abs_input = rewriter.create(loc, input); + Value abs_input_floor = rewriter.create(loc, abs_input); + Value abs_frac_input = + rewriter.create(loc, abs_input, abs_input_floor); + + // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve + // precision of pi * abs_frac_input for values of abs_frac_input close to 1. + Value one_minus_abs_frac_input = + rewriter.create(loc, one, abs_frac_input); + Value abs_frac_input_gt_one_half = + rewriter.create(loc, abs_frac_input, one_half); + Value reduced_frac_input = rewriter.create( + loc, abs_frac_input_gt_one_half, one_minus_abs_frac_input, + abs_frac_input); + Value pi_mul_reduced_frac_input = + rewriter.create(loc, pi, reduced_frac_input); + Value sin_pi_mul_reduced_frac_input = + rewriter.create(loc, pi_mul_reduced_frac_input); + Value reflection_denom = + rewriter.create(loc, sin_pi_mul_reduced_frac_input); + + // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, + // then it "wins" and the result is +/-inf. + Value is_finite = rewriter.create(loc, tensor_bool_type, + reflection_denom); + Value neg_reflection_denom = + rewriter.create(loc, reflection_denom); + Value log_pi_minus_reflection_denom = + rewriter.create(loc, log_pi, reflection_denom); + Value reflection_if_finite = + rewriter.create(loc, log_pi_minus_reflection_denom, log_y); + Value reflection = rewriter.create( + loc, is_finite, reflection_if_finite, neg_reflection_denom); + + Value result = rewriter.create(loc, need_to_reflect, + reflection, log_y); + + // lgamma(+/-inf) = +inf. + Value is_inf = rewriter.create(loc, tensor_bool_type, input); + result = rewriter.create(loc, is_inf, infinity, result); + + if (needs_cast) { + result = rewriter.create(loc, original_tensor_type, result); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + // Lowers Pack op to ConcatV2 op after changing shape of the inputs with // ExpandDims op. // @@ -777,9 +980,9 @@ class Lower_UnaryOpsComposition void PopulateLoweringTFPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { patterns->insert(context); + LowerDynamicStitchOp, LowerInvertPermutationOp, + LowerLgammaOp, LowerPackOp, LowerSpaceToBatchNDOp, + LowerSparseMatMulOp, Lower_UnaryOpsComposition>(context); populateWithGenerated(context, *patterns); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index bddc863ee60..fec4c20e98d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -273,13 +273,14 @@ def CreateTFShapeOp : NativeCodeCall< // TODO(hinsu): Support inputs of TensorList types. def LowerZerosLikeOp : - Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input), + Pat<(TF_ZerosLikeOp:$src_op + TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input), (TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)), (CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>; def LowerScatterNdOp : Pat<(TF_ScatterNdOp $indices, - TensorOf<[AnySignlessInteger, AnyFloat]>:$updates, $shape), - (TF_TensorScatterUpdateOp + TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape), + (TF_TensorScatterAddOp (TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))), $indices, $updates)>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc index 4abebc4475d..ac844b925ce 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/core/lib/monitoring/gauge.h" namespace mlir { namespace TFDevice { @@ -37,6 +38,11 @@ namespace { constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement"; +auto* auto_outside_compilation_gauge = + tensorflow::monitoring::Gauge::New( + "/tensorflow/core/use_auto_outside_compilation", + "Tracks if auto outside compilation is enabled"); + // This pass marks unsupported ops in a device cluster with // `_xla_outside_compilation` attribute so the operations will run on the host // instead of the device. Unsupported ops are ops that can not be code @@ -200,6 +206,9 @@ LogicalResult MarkUncompilableOps( outside_compiled_cluster_counter++; } }); + if (outside_compiled_cluster_counter > 0) { + auto_outside_compilation_gauge->GetCell()->Set(true); + } return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc index f477f38afc6..86eea50d744 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc @@ -16,66 +16,62 @@ limitations under the License. // This pass forms `tf_executor.island` per region of // `tf_device.parallel_execute`. // -// For example: +// For example, the following: +// +// %0 = tf_executor.island { +// tf_executor.yield +// } // %1:2 = tf_executor.island { // %2 = "tf.opA"(%arg0) : (tensor) -> tensor // tf_executor.yield %2 : tensor // } -// tf_executor.island() { -// "tf_device.parallel_execute"() ({ -// %3 = "tf.opB"() : () -> tensor -// tf_device.return %3 : tensor -// }, -// { +// %3:2 = tf_executor.island(%0) { +// %4 = "tf_device.parallel_execute"() ( { +// %5 = "tf.opB"() : () -> tensor +// tf_device.return %5 : tensor +// }, { // %5 = "tf.opC"(%1#0) : (tensor) -> tensor // tf_device.return // }) {} : () -> (tensor) +// tf_executor.yield %4 : tensor +// } +// tf_executor.fetch %3#0 : tensor +// +// gets lowered to: +// +// %0 = tf_executor.island { // tf_executor.yield // } -// tf_executor.fetch +// %1:2 = tf_executor.island { +// %2 = "tf.opA"(%arg0) : (tensor) -> tensor +// tf_executor.yield %2 : tensor +// } // -// Would become: -// %1:2 = tf_executor.island { -// %2 = "tf.opA"(%arg0) : (tensor) -> tensor -// tf_executor.yield %2 : tensor -// } +// // Island for the first region of above parallel_execute. +// %3:2 = tf_executor.island(%0) { +// %4 = "tf.opB"() : () -> tensor +// tf_executor.yield %4 : tensor +// } // -// // Input barrier sink island that forwards all inputs. -// %output_0, %control_1 = tf_executor.island { -// tf_executor.yield %1#0: tensor -// } +// // Island for the second region of above parallel_execute. +// %5 = tf_executor.island(%0) { +// %6 = "tf.opC"(%1#0) : (tensor) -> tensor +// tf_executor.yield +// } // -// // Island for the first region of above parallel_execute. -// %output_2, %control_3 = tf_executor.island(%control_1) { -// %3 = "tf.opB"() : () -> tensor -// tf_executor.yield %3 : tensor -// } -// -// // Island for the second region of above parallel_execute. -// %control_5 = tf_executor.island { -// %5 = "tf.opC"(%output_0) : (tensor) -> tensor -// tf_executor.yield -// } -// -// // Output barrier sink island that forwards all outputs. -// %output_5, %control_6 = tf_executor.island(%control_5) { -// tf_executor.yield %output_2 -// } +// tf_executor.fetch %3#0, %5 : tensor, !tf_executor.control // // When tf_device.parallel_execute op is enclosed after tf_device.replicate, // then this pass will run following `replicate-to-island` pass and // `tf-executor-break-up-islands` pass. #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -89,174 +85,117 @@ struct ParallelExecuteToIslandsPass }; // Convert parallel_execute op to a set of islands where each region of -// parallel_execute op becomes a separate island. This ensures that -// regions of parallel_execute op gets executed concurrently. -LogicalResult ExpandParallelExecuteToIslands( - tf_executor::IslandOp island_op, tf_executor::IslandOp input_sink_island, +// parallel_execute op becomes a separate island. This ensures that the regions +// of the parallel_execute op gets executed concurrently. +void ExpandParallelExecuteToIslands( + tf_executor::IslandOp island_op, tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder, - llvm::SmallVector* islands) { - const int num_executions = - parallel_execute_op.getOperation()->getNumRegions(); - llvm::SmallVector executions; - executions.reserve(num_executions); - builder->setInsertionPoint(island_op); + llvm::SmallVectorImpl& executes) { + const int num_regions = parallel_execute_op.getOperation()->getNumRegions(); + executes.reserve(num_regions); - auto control_type = tf_executor::ControlType::get(island_op.getContext()); - for (int i : llvm::seq(0, num_executions)) { - auto execute_region = - parallel_execute_op.GetRegionBlockWithIndex(i).getParent(); + for (int i : llvm::seq(0, num_regions)) { + Block& execute_block = parallel_execute_op.GetRegionBlockWithIndex(i); - // If region does not have any inputs, then add explicit control dependency - // from the input sink island. This guarantees that all inputs of - // parallel_execute op must be materialized before any of the islands are - // executed. - llvm::SetVector region_inputs; - getUsedValuesDefinedAbove(*execute_region, region_inputs); - llvm::SmallVector execution_control_inputs; - if (region_inputs.empty() && input_sink_island) - execution_control_inputs.emplace_back(input_sink_island.control()); - - // Collect result types and operands. - Operation* terminator = execute_region->front().getTerminator(); - llvm::SmallVector output_types(terminator->getOperandTypes()); - - // Replace terminator with YieldOp as island op always ends with yield op. + // Replace terminator with tf_executor.YieldOp. + Operation* terminator = execute_block.getTerminator(); builder->setInsertionPoint(terminator); - builder->create(terminator->getLoc(), - terminator->getOperands()); + auto yield = builder->create( + terminator->getLoc(), terminator->getOperands()); terminator->erase(); // Create new island for each region. builder->setInsertionPoint(island_op); - auto execution_island = builder->create( - island_op.getLoc(), output_types, control_type, - execution_control_inputs); + auto execute_island = builder->create( + island_op.getLoc(), yield.getOperandTypes(), + island_op.control().getType(), island_op.controlInputs()); - // Move over tf_device.parallel_execute body region into newly a - // created island. - execution_island.body().takeBody(*execute_region); - islands->push_back(execution_island); + // Move over tf_device.parallel_execute body region into newly the created + // island. + execute_island.body().takeBody(*execute_block.getParent()); + executes.push_back(execute_island); } - - return success(); } -// Creates an island that works as input sync point for islands. This guarantees -// that all (implicitly captured) inputs of parallel_execute are materialized -// before any of the islands are executed. -tf_executor::IslandOp CreateInputBarrierIsland( - OpBuilder* builder, tf_executor::IslandOp island_op) { - builder->setInsertionPoint(island_op); - - llvm::SetVector all_inputs; - getUsedValuesDefinedAbove(island_op.body(), all_inputs); - - // Filter out values that are arguments and doesn't need to be part of the - // entry barrier. - llvm::SmallVector island_inputs; - llvm::SmallVector input_types; - island_inputs.reserve(all_inputs.size()); - input_types.reserve(all_inputs.size()); - for (Value val : all_inputs) { - if (!val.isa()) { - island_inputs.push_back(val); - input_types.push_back(val.getType()); - } - } - if (island_inputs.empty() && island_op.controlInputs().empty()) return {}; - - // Create new island for that forwards all inputs. - auto control_type = tf_executor::ControlType::get(island_op.getContext()); - auto input_sink_island = builder->create( - island_op.getLoc(), input_types, control_type, island_op.controlInputs()); - input_sink_island.body().push_back(new Block); - - for (auto input_index_and_value : llvm::enumerate(island_inputs)) { - int index = input_index_and_value.index(); - Value input_value = input_index_and_value.value(); - replaceAllUsesInRegionWith(input_value, input_sink_island.getResult(index), - island_op.body()); - } - - // Create YieldOp for the new input sink island. - builder->setInsertionPointToEnd(&input_sink_island.GetBody()); - builder->create(island_op.getLoc(), island_inputs); - return input_sink_island; -} - -// Creates an islands that works as output sync point. This guarantees that -// execution of all islands must be completed before op following -// parallel_execute runs. -tf_executor::IslandOp CreateOutputBarrierIsland( - OpBuilder* builder, tf_executor::IslandOp island_op, - llvm::SmallVectorImpl* islands) { - // Add control dependency to island operand if island output has no uses. - llvm::SmallVector island_operands; - for (auto& island : *islands) - if (island.use_empty()) island_operands.push_back(island.control()); - - // Create single island forwarding all island results. - builder->setInsertionPoint(island_op); - auto island_output_sink = builder->create( - island_op.getLoc(), llvm::to_vector<8>(island_op.getResultTypes()), - island_operands); - island_output_sink.body().push_back(new Block); - return island_output_sink; -} - -LogicalResult CreateIslandsFromParallelExecute( +void CreateIslandsFromParallelExecute( tf_executor::IslandOp island_op, tf_device::ParallelExecuteOp parallel_execute_op) { OpBuilder builder(island_op); - auto input_sink_island = CreateInputBarrierIsland(&builder, island_op); - // Create N islands where N is the number of regions inside parallel_execute - // op. - llvm::SmallVector islands; - auto result = ExpandParallelExecuteToIslands( - island_op, input_sink_island, parallel_execute_op, &builder, &islands); - if (failed(result)) return result; + // Create islands for each region of the parallel_execute op. + llvm::SmallVector executes; + ExpandParallelExecuteToIslands(island_op, parallel_execute_op, &builder, + executes); - // Remap all results of parallel_execute op with outputs from newly - // created islands. + // Remap all results of parallel_execute op with outputs from newly created + // islands. llvm::SmallVector parallel_execute_outputs; parallel_execute_outputs.reserve( parallel_execute_op.getOperation()->getNumResults()); - for (auto island : islands) - for (auto output_value : island.outputs()) - parallel_execute_outputs.emplace_back(output_value); + for (auto& execute : executes) + parallel_execute_outputs.append(execute.outputs().begin(), + execute.outputs().end()); - parallel_execute_op.getOperation()->replaceAllUsesWith( - parallel_execute_outputs); + for (auto result : llvm::zip(island_op.outputs(), parallel_execute_outputs)) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); - auto island_output_sink = - CreateOutputBarrierIsland(&builder, island_op, &islands); + // Add sink island to pin all islands as a control dependency if there is a + // control dependency leading from the parallel_execute originally. + if (!island_op.control().use_empty()) { + llvm::SmallVector island_operands; + for (auto& execute : executes) island_operands.push_back(execute.control()); + + builder.setInsertionPoint(island_op); + auto island_sink = builder.create( + island_op.getLoc(), llvm::ArrayRef{}, + island_op.control().getType(), island_operands); + island_sink.body().push_back(new Block); + builder.setInsertionPointToEnd(&island_sink.GetBody()); + builder.create(island_op.getLoc(), + llvm::ArrayRef{}); + island_op.control().replaceAllUsesWith(island_sink.control()); + } + + // Islands with no uses should be pinned to a graph fetch so they still + // execute. + llvm::SmallVector unused_execute_controls; + for (auto& execute : executes) + if (execute.use_empty()) + unused_execute_controls.push_back(execute.control()); + + if (!unused_execute_controls.empty()) { + auto graph_op = island_op.getParentOfType(); + tf_executor::FetchOp fetch = graph_op.GetFetch(); + auto fetches = llvm::to_vector<8>(fetch.getOperands()); + fetches.append(unused_execute_controls.begin(), + unused_execute_controls.end()); + builder.setInsertionPoint(fetch); + builder.create(fetch.getLoc(), fetches); + fetch.erase(); + } - // Move island YieldOp over to new single island and remap island results. - island_op.GetYield().getOperation()->moveBefore( - &island_output_sink.GetBody(), island_output_sink.GetBody().begin()); - island_op.replaceAllUsesWith(island_output_sink); island_op.erase(); - - return success(); -} - -// Finds islands with a single `tf_device.parallel_execute` and create -// individual islands per region of parallel_execute. -void LowerSingleIslandParallelExecuteToIslands( - tf_executor::IslandOp island_op) { - if (!hasSingleElement(island_op.GetBody().without_terminator())) return; - - if (auto parallel_execute_op = llvm::dyn_cast( - &island_op.GetBody().front())) - CreateIslandsFromParallelExecute(island_op, parallel_execute_op); } void ParallelExecuteToIslandsPass::runOnFunction() { - getFunction().walk([&](tf_executor::IslandOp island_op) { - LowerSingleIslandParallelExecuteToIslands(island_op); + // Find islands with a single `tf_device.parallel_execute` and create + // individual islands per execute region of the parallel_execute. + llvm::SmallVector parallel_execute_op_islands; + getFunction().walk([&](tf_executor::GraphOp graph_op) { + for (auto island_op : graph_op.getOps()) { + if (!island_op.WrapsSingleOp()) continue; + + if (isa(&island_op.GetBody().front())) + parallel_execute_op_islands.push_back(island_op); + } }); + + for (tf_executor::IslandOp island_op : parallel_execute_op_islands) { + auto parallel_execute_op = + cast(island_op.GetBody().front()); + CreateIslandsFromParallelExecute(island_op, parallel_execute_op); + } } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index a4ddb713ec0..4a12c80c8d1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -84,6 +84,10 @@ std::unique_ptr> CreateGpuOpFusionPass(); std::unique_ptr> CreateTensorDeviceCopyConversionPass(); +// Returns a pass that folds tf.BroadcastTo nodes with subsequent nodes if they +// have built in broadcasting support. +std::unique_ptr> CreateBroadcastFoldPass(); + struct LayoutOptimizationPipelineOptions : public PassPipelineOptions { Option force_data_format{ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 3cd316cf92d..e635f13f018 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -159,6 +159,26 @@ Type GetResourceSubtype(Value value) { return nullptr; } +// Replaces all `tf.VarIsInitializedOp` in a block with a constant true. +// TODO(b/171039585): Replace this with proper analysis of +// `tf.VarIsInitializedOp` in regards to resource writes and control flow. +void SetAllVarIsInitializedToTrue(Block* block) { + auto builder = OpBuilder::atBlockBegin(block); + TF::ConstOp const_true = nullptr; + for (auto op : + llvm::make_early_inc_range(block->getOps())) { + builder.setInsertionPoint(op); + if (!const_true) + const_true = builder.create( + op.getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true)); + + op.is_initialized().replaceAllUsesWith(const_true); + op.erase(); + } +} + // Performs store-load forwarding. This effectively removes // 1) Any resource loads after a store to that same resource is done // 2) Any resource stores except the last one. @@ -181,8 +201,20 @@ void ForwardStoreToLoad(Block* block) { if (!last_store) continue; // Use stored value in last_store to replace all uses of current resource - // load's result, then erase this resource load. - read_variable_op.value().replaceAllUsesWith(last_store.value()); + // load's result, then erase this resource load. Add an intermediate + // CastOp if the shape of types doesn't exactly match. + Type read_type = read_variable_op.value().getType(); + if (read_type != last_store.value().getType()) { + OpBuilder builder(last_store); + builder.setInsertionPointAfter(last_store); + auto cast = builder.create( + last_store.getLoc(), read_type, last_store.value(), + /*Truncate=*/builder.getBoolAttr(false)); + read_variable_op.value().replaceAllUsesWith(cast); + } else { + read_variable_op.value().replaceAllUsesWith(last_store.value()); + } + read_variable_op.erase(); continue; } @@ -463,7 +495,7 @@ void RegionResourceHoister::AppendResourceStoreValueToReturn( auto new_return_operands = llvm::to_vector<4>(old_return->getOperands()); new_return_operands.resize(num_new_results_); - // initialize return values for written resources to be the hosited reads. + // initialize return values for written resources to be the hoisted reads. for (Value resource : written_resources_) { const ResourceInfo& info = resources_[resource]; new_return_operands[info.result_index] = info.hoisted_read; @@ -767,8 +799,6 @@ LogicalResult LiftArgRetResourcesForFunction( FuncOp func_op, const llvm::SmallDenseMap& resource_data_types, llvm::function_ref handle_updated_arg_value) { - ForwardStoreToLoad(&func_op.front()); - RegionResourceHoister hoister(func_op); if (failed(hoister.Analyze())) return failure(); @@ -1167,7 +1197,7 @@ void UpdatePartitionedCallOpWithNewCallee( } LogicalResult HoistForControlFlow( - Block*, ModuleOp, + Block*, ModuleOp, bool, llvm::SmallDenseMap*); // A templated routine for handling both PartitionedCallOp and @@ -1176,14 +1206,15 @@ LogicalResult HoistForControlFlow( // flow, then performs lifting on the callee. template LogicalResult HandlePartitionedCallOp( - CallOpType call_op, FuncOp callee, ModuleOp module, + CallOpType call_op, FuncOp callee, ModuleOp module, bool vars_initialized, llvm::SmallDenseMap* lifted_callees) { auto emplace_res = lifted_callees->try_emplace(callee.getName(), PartitionedCallLiftingInfo()); if (emplace_res.second) { // Unseen callee. Perform resource lifting on it. - if (failed(HoistForControlFlow(&callee.front(), module, lifted_callees))) + if (failed(HoistForControlFlow(&callee.front(), module, vars_initialized, + lifted_callees))) return failure(); if (failed(HandlePartitionedCallOpCallee( @@ -1198,26 +1229,28 @@ LogicalResult HandlePartitionedCallOp( // Hoists resource loads/stores from control flow ops in `block` outside the // body/cond/branch/callee functions. LogicalResult HoistForControlFlow( - Block* block, ModuleOp module, + Block* block, ModuleOp module, bool vars_initialized, llvm::SmallDenseMap* lifted_partitioned_call_callees) { + if (vars_initialized) SetAllVarIsInitializedToTrue(block); + for (Operation& op : llvm::make_early_inc_range(*block)) { if (auto while_op = llvm::dyn_cast(&op)) { auto body = while_op.body_function(); auto cond = while_op.cond_function(); // Recursively handle the nested control flow. - HoistForControlFlow(&body.front(), module, + HoistForControlFlow(&body.front(), module, vars_initialized, lifted_partitioned_call_callees); - HoistForControlFlow(&cond.front(), module, + HoistForControlFlow(&cond.front(), module, vars_initialized, lifted_partitioned_call_callees); if (failed(HandleWhileLoop(while_op, body, cond))) return failure(); } else if (auto if_op = llvm::dyn_cast(&op)) { auto then_branch = if_op.then_function(); auto else_branch = if_op.else_function(); // Recursively handle the nested control flow. - HoistForControlFlow(&then_branch.front(), module, + HoistForControlFlow(&then_branch.front(), module, vars_initialized, lifted_partitioned_call_callees); - HoistForControlFlow(&else_branch.front(), module, + HoistForControlFlow(&else_branch.front(), module, vars_initialized, lifted_partitioned_call_callees); if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch}))) return failure(); @@ -1226,7 +1259,7 @@ LogicalResult HoistForControlFlow( case_op.get_branch_functions(branch_functions); for (FuncOp func : branch_functions) { // Recursively handle the nested control flow. - HoistForControlFlow(&func.front(), module, + HoistForControlFlow(&func.front(), module, vars_initialized, lifted_partitioned_call_callees); } if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure(); @@ -1237,6 +1270,7 @@ LogicalResult HoistForControlFlow( "resource lifting does not support call with nested references."); } if (failed(HandlePartitionedCallOp(call_op, callee, module, + vars_initialized, lifted_partitioned_call_callees))) { // Nested control flow handling is done in HandlePartitionedCallOp(). return failure(); @@ -1244,12 +1278,13 @@ LogicalResult HoistForControlFlow( } else if (auto call_op = llvm::dyn_cast(&op)) { if (failed(HandlePartitionedCallOp(call_op, call_op.func(), module, + vars_initialized, lifted_partitioned_call_callees))) { return failure(); } } else if (isa(op)) { for (Region& region : op.getRegions()) - HoistForControlFlow(®ion.front(), module, + HoistForControlFlow(®ion.front(), module, vars_initialized, lifted_partitioned_call_callees); LogicalResult result = RegionResourceHoister::ReplaceOpWithNewOp(&op); if (failed(result)) return failure(); @@ -1277,7 +1312,8 @@ void ResourceOpLiftingPass::runOnOperation() { auto walk_result = module.walk([&](FuncOp func_op) { return func_op.walk([&](tf_device::ClusterOp cluster) { LogicalResult result = HoistForControlFlow( - &cluster.GetBody(), module, &lifted_partitioned_call_callees); + &cluster.GetBody(), module, /*vars_initialized=*/true, + &lifted_partitioned_call_callees); if (failed(result)) return WalkResult::interrupt(); result = RegionResourceHoister::ReplaceOpWithNewOp(cluster); if (failed(result)) return WalkResult::interrupt(); @@ -1340,9 +1376,9 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) { llvm::SmallDenseMap lifted_partitioned_call_callees; - if (failed(HoistForControlFlow(&function.front(), - cast(function.getParentOp()), - &lifted_partitioned_call_callees))) + if (failed(HoistForControlFlow( + &function.front(), cast(function.getParentOp()), + /*vars_initialized=*/false, &lifted_partitioned_call_callees))) return failure(); // Clean up and canonicalize to remove dead local variables as some local diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index 6bd8baf9c99..05eef4d5045 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -464,6 +464,38 @@ LogicalResult HandleStackPopV2Op( return success(); } +LogicalResult HandleRegionControlFlowOps( + Operation& op, ModuleOp module, + llvm::SmallDenseMap* data_var_to_size_var, + llvm::StringMap* + decomposed_partitioned_call_callees) { + for (OpOperand& operand : op.getOpOperands()) { + if (getElementTypeOrSelf(operand.get().getType()).isa()) { + return op.emitOpError() + << "found unexpected type " << operand.get().getType() + << " of operand #" << operand.getOperandNumber() + << ", resource type operands are expected to have been " + "canonicalized away for region based control flow ops"; + } + } + for (OpResult result : op.getResults()) { + if (getElementTypeOrSelf(result.getType()).isa()) { + return op.emitOpError() + << "found unexpected type " << result.getType() << " of result #" + << result.getResultNumber() + << ", resource type results are expected to have been " + "canonicalized away for region based control flow ops"; + } + } + for (Region& region : op.getRegions()) { + if (failed(DecomposeStackOpsInternal(®ion.front(), module, + data_var_to_size_var, + decomposed_partitioned_call_callees))) + return failure(); + } + return success(); +} + // Decomposes stack ops on a region and recursively decomposes called functions. // data_var_to_size_var: a mapping from stacks' buffer local variables to size // local variables. @@ -505,6 +537,13 @@ LogicalResult DecomposeStackOpsInternal( decomposed_partitioned_call_callees))) { return failure(); } + } else if (llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op)) { + if (failed( + HandleRegionControlFlowOps(op, module, data_var_to_size_var, + decomposed_partitioned_call_callees))) + return failure(); } else if (auto pcall = llvm::dyn_cast(&op)) { if (!pcall.func()) { return pcall.emitOpError( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index 680d5334ceb..8ad4687d537 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -458,38 +458,41 @@ llvm::SmallDenseMap> AccessedGradients( ArrayRef funcs, ModuleOp module) { llvm::SmallDenseMap> result; llvm::SmallDenseMap> result_sets; - auto insert = [&](Value v, const string& source) { - auto arg = v.cast(); - if (!arg) return; + auto insert = [&](Value v, const string& source, const Block& func_block) { + auto arg = v.dyn_cast(); + if (!arg || arg.getOwner() != &func_block) return; auto insert_res = result_sets[arg.getArgNumber()].insert(source); if (!insert_res.second) return; result[arg.getArgNumber()].push_back(source); }; for (FuncOp func : funcs) { - for (auto& op : func.front().getOperations()) { - if (llvm::isa(&op)) { - op.replaceAllUsesWith(op.getOperands()); - continue; + const Block& func_block = func.front(); + // Walk all operations and nested regions to find accessed gradient sources + // for function arguments. + func.walk([&](Operation* op) { + if (llvm::isa(op)) { + op->replaceAllUsesWith(op->getOperands()); + return; } - if (auto grad = llvm::dyn_cast(&op)) { - insert(grad.handle(), grad.source().str()); - } else if (auto while_op = llvm::dyn_cast(&op)) { + if (auto grad = llvm::dyn_cast(op)) { + insert(grad.handle(), grad.source().str(), func_block); + } else if (auto while_op = llvm::dyn_cast(op)) { for (const auto& entry : AccessedGradients( {while_op.body_function(), while_op.cond_function()}, module)) for (const string& source : entry.getSecond()) - insert(while_op.getOperand(entry.getFirst()), source); - } else if (auto if_op = llvm::dyn_cast(&op)) { + insert(while_op.getOperand(entry.getFirst()), source, func_block); + } else if (auto if_op = llvm::dyn_cast(op)) { for (const auto& entry : AccessedGradients( {if_op.then_function(), if_op.else_function()}, module)) for (const string& source : entry.getSecond()) - insert(if_op.getOperand(entry.getFirst() + 1), source); - } else if (auto call = llvm::dyn_cast(&op)) { + insert(if_op.getOperand(entry.getFirst() + 1), source, func_block); + } else if (auto call = llvm::dyn_cast(op)) { auto callee = dyn_cast(call.resolveCallable()); for (const auto& entry : AccessedGradients({callee}, module)) for (const string& source : entry.getSecond()) - insert(call.getArgOperands()[entry.getFirst()], source); + insert(call.getArgOperands()[entry.getFirst()], source, func_block); } - } + }); } return result; } @@ -810,6 +813,38 @@ LogicalResult HandlePartitionedCallOp( return success(); } +LogicalResult HandleRegionControlFlowOps( + Operation& op, ModuleOp module, + llvm::SmallDenseMap* stats, + llvm::StringMap* + decomposed_partitioned_call_callees) { + for (OpOperand& operand : op.getOpOperands()) { + if (getElementTypeOrSelf(operand.get().getType()).isa()) { + return op.emitOpError() + << "found unexpected type " << operand.get().getType() + << " of operand #" << operand.getOperandNumber() + << ", resource type operands are expected to have been " + "canonicalized away for region based control flow ops"; + } + } + for (OpResult result : op.getResults()) { + if (getElementTypeOrSelf(result.getType()).isa()) { + return op.emitOpError() + << "found unexpected type " << result.getType() << " of result #" + << result.getResultNumber() + << ", resource type results are expected to have been " + "canonicalized away for region based control flow ops"; + } + } + + for (Region& region : op.getRegions()) { + if (failed(DecomposeTensorArrayOps(®ion.front(), module, stats, + decomposed_partitioned_call_callees))) + return failure(); + } + return success(); +} + LogicalResult DecomposeTensorArrayOps( Block* block, ModuleOp module, llvm::SmallDenseMap* stats, @@ -853,6 +888,12 @@ LogicalResult DecomposeTensorArrayOps( decomposed_partitioned_call_callees))) { return failure(); } + } else if (llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op)) { + if (failed(HandleRegionControlFlowOps( + op, module, stats, decomposed_partitioned_call_callees))) + return failure(); } else if (auto pcall = llvm::dyn_cast(&op)) { auto callee = pcall.func(); if (!callee) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index fff09ccf363..46bc094e5ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -172,6 +172,7 @@ bool ShouldMoveOpAfterCluster( const llvm::SmallSetVector& preceding_users, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis, const llvm::SmallDenseSet& observed_resource_ids) { + const bool is_replicate = llvm::isa(op); auto result = op->walk([&](Operation* inner_op) { for (Value operand : inner_op->getOperands()) { Operation* def = operand.getDefiningOp(); @@ -186,6 +187,11 @@ bool ShouldMoveOpAfterCluster( } } + // Don't visit replicate op inner op operands as new resource + // values/arguments may have been created but are not known in + // `resource_alias_analysis`. + if (is_replicate && inner_op != op) return WalkResult::advance(); + // Check for uses of any resource in or after cluster. for (Value operand : TF::filter_resources(inner_op->getOperands())) { if (resource_alias_analysis.IsUnknownResource(operand)) continue; @@ -424,20 +430,24 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { for (auto result_and_idx : llvm::enumerate(cluster.getResults())) { Value result = result_and_idx.value(); int idx = result_and_idx.index(); - for (auto& use : result.getUses()) { - Operation* def = use.getOwner(); - if (!def || !llvm::isa(def)) - return cluster.emitError() - << "requires output of " << cluster.getOperationName() - << " to lead to a 'tf.TPUReplicatedOutput' op"; + auto replicate_outputs = llvm::make_range( + std::next(replicate_op.result_begin(), idx * num_replicas), + std::next(replicate_op.result_begin(), (idx + 1) * num_replicas)); - const int def_NumResults = def->getNumResults(); - if (def_NumResults != num_replicas) + for (auto& use : llvm::make_early_inc_range(result.getUses())) { + Operation* def = use.getOwner(); + if (!llvm::isa(def)) { + // If user is not a `tf.TPUReplicatedOutput`, simply forward the first + // replica output. Certain Graphs under V1 create `tf.Identity` users of + // replicated ops to pin the TPU computation for execution. + use.set(*replicate_outputs.begin()); + continue; + } + + const int def_num_results = def->getNumResults(); + if (def_num_results != num_replicas) return def->emitOpError() << "requires " << num_replicas << " results"; - auto replicate_outputs = llvm::make_range( - std::next(replicate_op.result_begin(), idx * num_replicas), - std::next(replicate_op.result_begin(), (idx + 1) * num_replicas)); def->replaceAllUsesWith(replicate_outputs); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index 0b9eaba8c97..35ad3d21b30 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -47,133 +48,47 @@ struct TPUShardingIdentificationPass void runOnOperation() override; }; -// Sets `sharding_op` if `op` is XlaShardingOp or if XlaSharding op is adjacent -// to `op`. XlaSharding op may be direct user of inputs but it may also be -// followed by an Identity op and, in the case where bfloat16 type is used, Cast -// op may be added right after the input. As so, parse the users of the -// operation to access connected XlaSharding op. +// Finds XlaSharding op connected to an argument value. If value is a resource +// type then XlaSharding op will be connected to a ReadVariable op. XlaSharding +// op may be direct user of inputs but it may also be followed by an Identity op +// and, in the case where bfloat16 type is used, Cast op may be added right +// after the input. // +// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If, +// Case, While) ops and Caller return values. // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded // inputs. -void GetAdjacentXlaShardingOp(Operation* op, - llvm::Optional* sharding_op) { - // TODO(hongjunchoi): Detect the case when sharding configuration is ambiguous - // for a single input (i.e. multiple different XlaSharding ops with different - // configuration policies are connected). - if (sharding_op->hasValue()) return; +llvm::Optional GetXlaShardingFromArg(const Value& value) { + llvm::SmallPtrSet visited_values; + llvm::SmallVector values_to_visit{value}; + while (!values_to_visit.empty()) { + llvm::SmallVector next_values_to_visit; + for (Value value_to_visit : values_to_visit) { + if (!visited_values.insert(value_to_visit).second) continue; - if (auto sharding = llvm::dyn_cast(op)) { - sharding_op->emplace(sharding); - return; - } + for (auto& use : value_to_visit.getUses()) { + Operation* owner = use.getOwner(); + if (auto sharding = llvm::dyn_cast(owner)) + return sharding._XlaSharding(); - if (llvm::isa(op)) { - for (auto user : op->getUsers()) - GetAdjacentXlaShardingOp(user, sharding_op); - } -} + if (llvm::isa(owner)) { + next_values_to_visit.push_back(use.getOwner()->getResult(0)); + continue; + } -// Parses XlaSharding op connected to input args. If Input to -// tf_device.ClusterFunc op is of resource type, then XlaSharding op will be -// connected to following ReadVariable op. -// -// TODO(hongjunchoi): Add logic to parse XlaSharding op inside a Call op or -// If/While op. -llvm::Optional ParseInputSharding(const Value& arg) { - llvm::Optional parsed_sharding_op; - for (auto user : arg.getUsers()) { - if (parsed_sharding_op) continue; - - GetAdjacentXlaShardingOp(user, &parsed_sharding_op); - if (parsed_sharding_op) continue; - - if (llvm::isa(user)) - for (auto read_variable_user : user->getUsers()) - GetAdjacentXlaShardingOp(read_variable_user, &parsed_sharding_op); - } - - if (!parsed_sharding_op) return llvm::Optional(); - return parsed_sharding_op.getValue()._XlaSharding(); -} - -// Returns the provided sharding configuration if operand of return value of -// tf_device.ClusterFunc op is directly from XlaSharding op, -llvm::Optional ParseReturnValueSharding(FuncOp func, - const int output_index, - const OpOperand& operand) { - if (auto sharding_op = llvm::dyn_cast_or_null( - operand.get().getDefiningOp())) - return sharding_op._XlaSharding(); - - return llvm::Optional(); -} - -// Includes information on Func op and argument index of the input value. This -// is used to trace Value that is fed into function call ops. -struct FunctionAndArgumentInfo { - FuncOp func; - int argument_index; -}; - -// Adds tf.PartitionedCall op or tf.StatefulPartitionedCall op to `list`. If -// `op` is a function call op, then find the func op from provided `module` and -// add the func op with `arg_index` to `list`. `list` will later be used to -// trace mlir::Value that is fed into (potentially nested) function call ops. -void AddFunctionalOpsToList( - const int arg_index, ModuleOp module, Operation* op, - llvm::SmallVectorImpl* list) { - if (auto pcall_op = llvm::dyn_cast(op)) { - if (!pcall_op.f().isa()) return; - - auto pcall_func = llvm::cast( - module.lookupSymbol(pcall_op.f().getRootReference())); - assert(pcall_func); - list->emplace_back(FunctionAndArgumentInfo{pcall_func, arg_index}); - - } else if (auto spcall_op = - llvm::dyn_cast(op)) { - auto sp_call_func = llvm::cast(module.lookupSymbol(spcall_op.f())); - assert(sp_call_func); - list->emplace_back(FunctionAndArgumentInfo{sp_call_func, arg_index}); - } -} - -// Walks the MLIR graph from `arg` and return a list of all function call ops to -// which the `arg` op is directly connected. -// -// For example: -// argument0 -> PartitionedCallOp -> StatefulPartitionedCallOp -> AddOp -// -// For above case, PartitionedCall op and StatefulPartitionedCallOp will be -// returned. -llvm::SmallVector ExtractFunctionsConnectedToArg( - BlockArgument arg, ModuleOp module) { - llvm::SmallVector functions_connected_to_arg; - for (auto& arg_use : arg.getUses()) - AddFunctionalOpsToList(arg_use.getOperandNumber(), module, - arg_use.getOwner(), &functions_connected_to_arg); - - llvm::SmallVector functions_to_parse{ - functions_connected_to_arg.begin(), functions_connected_to_arg.end()}; - - while (!functions_to_parse.empty()) { - llvm::SmallVector newly_discovered_functions; - for (auto function_info : functions_to_parse) { - Block& func_entry_block = function_info.func.front(); - auto argument = - func_entry_block.getArgument(function_info.argument_index); - - for (auto& arg_use : argument.getUses()) - AddFunctionalOpsToList(arg_use.getOperandNumber(), module, - arg_use.getOwner(), &newly_discovered_functions); + if (auto call_op = llvm::dyn_cast(owner)) { + FuncOp func = llvm::dyn_cast(call_op.resolveCallable()); + if (!func) continue; + next_values_to_visit.push_back( + func.getArgument(use.getOperandNumber())); + } + } } - functions_connected_to_arg.append(newly_discovered_functions.begin(), - newly_discovered_functions.end()); - std::swap(functions_to_parse, newly_discovered_functions); + values_to_visit.swap(next_values_to_visit); } - return functions_connected_to_arg; + return llvm::None; } // Walks the graph from the arguments of the `cluster_func_op` and extracts @@ -186,7 +101,6 @@ void IdentifyXlaShardingForComputationInputs( FuncOp cluster_function, Builder* builder) { // Look up function definition from module. Block& cluster_function_block = cluster_function.front(); - ModuleOp module = cluster_func_op.getParentOfType(); llvm::SmallVector sharding_for_args( cluster_function_block.getNumArguments(), logical_core_0_sharding); @@ -202,31 +116,17 @@ void IdentifyXlaShardingForComputationInputs( // Sharding configurations are added to the tf_device.ClusterFunc as an // attribute and the function as an argument attribute. for (auto& arg : cluster_function_block.getArguments()) { - auto arg_sharding = ParseInputSharding(arg); - const int arg_index_to_tpu_computation = arg.getArgNumber(); - - if (!arg_sharding.hasValue()) { - auto connected_functions_to_arg = - ExtractFunctionsConnectedToArg(arg, module); - for (auto& function_arg_info : connected_functions_to_arg) { - if (arg_sharding.hasValue()) break; - - const int function_argument_index = function_arg_info.argument_index; - auto& parsed_function = function_arg_info.func; - Block& parsed_function_block = parsed_function.front(); - arg_sharding = ParseInputSharding( - parsed_function_block.getArgument(function_argument_index)); - } - } + auto arg_sharding = GetXlaShardingFromArg(arg); + const int index = arg.getArgNumber(); if (arg_sharding) { - sharding_for_args[arg_index_to_tpu_computation] = arg_sharding.getValue(); + sharding_for_args[index] = arg_sharding.getValue(); cluster_function.setArgAttr( - arg_index_to_tpu_computation, kShardingAttr, + index, kShardingAttr, builder->getStringAttr(arg_sharding.getValue())); } else { cluster_function.setArgAttr( - arg_index_to_tpu_computation, kShardingAttr, + index, kShardingAttr, builder->getStringAttr(logical_core_0_sharding)); } } @@ -235,6 +135,44 @@ void IdentifyXlaShardingForComputationInputs( builder->getStrArrayAttr(sharding_for_args)); } +// Finds XlaSharding op connected to a result value. XlaSharding op may be +// direct user of inputs but it may also be followed by an Identity op and, in +// the case where bfloat16 type is used, Cast op may be added right after the +// input. +// +// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If, +// Case, While) ops and Caller argument values. +// TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded +// inputs. +llvm::Optional GetXlaShardingFromRetval(const Value& value) { + llvm::SmallPtrSet visited_values; + Value value_to_visit = value; + while (value_to_visit) { + if (!visited_values.insert(value_to_visit).second) return llvm::None; + + Operation* def = value_to_visit.getDefiningOp(); + if (auto sharding = llvm::dyn_cast_or_null(def)) + return sharding._XlaSharding(); + + if (llvm::isa_and_nonnull(def)) { + value_to_visit = def->getOperand(0); + continue; + } + + if (auto call_op = llvm::dyn_cast_or_null(def)) { + FuncOp func = llvm::dyn_cast(call_op.resolveCallable()); + if (!func) continue; + value_to_visit = func.front().getTerminator()->getOperand( + value_to_visit.cast().getResultNumber()); + continue; + } + + break; + } + + return llvm::None; +} + // Parses XlaSharding op directly connected from the outputs of the // `cluster_func` and extract sharding configurations for outputs. void IdentifyXlaShardingForComputationOutputs( @@ -252,8 +190,8 @@ void IdentifyXlaShardingForComputationOutputs( // tf_device.ClusterFunc as an attribute and the function as a result // attribute. for (auto& ret : terminator->getOpOperands()) { + auto ret_sharding = GetXlaShardingFromRetval(ret.get()); const int index = ret.getOperandNumber(); - auto ret_sharding = ParseReturnValueSharding(func, index, ret); if (ret_sharding) { sharding_for_rets[index] = ret_sharding.getValue(); @@ -264,6 +202,7 @@ void IdentifyXlaShardingForComputationOutputs( builder->getStringAttr(logical_core_0_sharding)); } } + cluster_func.setAttr(tensorflow::kOutputShardingAttr, builder->getStrArrayAttr(sharding_for_rets)); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 42ce5c533a2..efbbc43967c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -104,6 +104,7 @@ limitations under the License. #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/utils/transitive_fanin.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/path.h" @@ -130,6 +131,10 @@ using stream_executor::port::StatusOr; namespace { +auto* reference_variable_gauge = tensorflow::monitoring::Gauge::New( + "/tensorflow/core/uses_reference_variables", + "Tracks if reference variables are used anywhere in the graph"); + constexpr char kTpuReplicateAttr[] = "_tpu_replicate"; bool IsOutputShapesAttribute(const AttrValue& attr_value, @@ -2057,6 +2062,11 @@ class GraphDefImporter : public ImporterBase { llvm::StringRef func_name); private: + // Checks if a Module contains any ref variables in any operation operands + // or results, including checking Block arguments and operations within + // regions. + static bool ModuleContainsRefType(mlir::ModuleOp module); + explicit GraphDefImporter( const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, const GraphImportConfig& specs, mlir::ModuleOp module, @@ -2092,6 +2102,38 @@ class GraphDefImporter : public ImporterBase { absl::InlinedVector* control_ret_nodes); }; +bool IsTensorFlowRefType(mlir::Type ty) { + return mlir::getElementTypeOrSelf(ty).isa(); +} + +bool OpHasRefTypeOperandOrResult(mlir::Operation* op) { + // Check op operands. + for (mlir::Type ty : op->getOperandTypes()) + if (IsTensorFlowRefType(ty)) return true; + // Check op results. + for (mlir::Type ty : op->getResultTypes()) + if (IsTensorFlowRefType(ty)) return true; + // Check all block arguments within any regions the op has. + for (mlir::Region& region : op->getRegions()) + for (mlir::Block& block : region) + for (auto& arg : block.getArguments()) + if (IsTensorFlowRefType(arg.getType())) return true; + return false; +} + +bool GraphDefImporter::ModuleContainsRefType(mlir::ModuleOp module) { + // If walk is interrupted at any point, that means a ref variable was found. + // At this point, we've confirmed existence of a ref variable and don't need + // to continue looking. + return module + .walk([&](mlir::Operation* op) { + if (OpHasRefTypeOperandOrResult(op)) + return mlir::WalkResult::interrupt(); + return mlir::WalkResult::advance(); + }) + .wasInterrupted(); +} + StatusOr GraphDefImporter::Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, @@ -2126,28 +2168,27 @@ StatusOr GraphDefImporter::Convert( TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, &control_ret_nodes)); - if (!arg_nodes.empty() || !ret_nodes.empty() || - !control_ret_nodes.empty()) { - mlir::Builder b(context); - std::string s; - llvm::raw_string_ostream ss(s); - auto node_name = [&](const OutputTensor& tensor) { - ss << tensor.node->name(); - }; - llvm::interleave(arg_nodes, ss, node_name, ","); - auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); - s.clear(); - llvm::interleave(ret_nodes, ss, node_name, ","); - auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); - s.clear(); - llvm::interleave(specs.control_outputs, ss, ","); - auto control_outputs = - b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); + mlir::Builder b(context); + std::string s; + llvm::raw_string_ostream ss(s); + auto node_name = [&](const OutputTensor& tensor) { + ss << tensor.node->name(); + }; + llvm::interleave(arg_nodes, ss, node_name, ","); + auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); + s.clear(); + llvm::interleave(ret_nodes, ss, node_name, ","); + auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); + s.clear(); + llvm::interleave(specs.control_outputs, ss, ","); + auto control_outputs = + b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); - attrs.push_back(b.getNamedAttr( - "tf.entry_function", - b.getDictionaryAttr({inputs, outputs, control_outputs}))); - } + // Under `graph_as_function` mode, `tf.entry_function` is always set as it + // is assumed feed, fetch, and target nodes are set correctly. + attrs.push_back(b.getNamedAttr( + "tf.entry_function", + b.getDictionaryAttr({inputs, outputs, control_outputs}))); } else { // Collects the argument and return nodes by looking up the node names // specified by the user. @@ -2190,6 +2231,13 @@ StatusOr GraphDefImporter::Convert( TF_RETURN_IF_ERROR(importer.ImporterBase::Convert( func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs)); + // Check if there are any reference variables in the module. + bool contains_ref_var = ModuleContainsRefType(*module); + reference_variable_gauge->GetCell()->Set(contains_ref_var); + if (contains_ref_var) { + VLOG(1) << "Graph contains one or more reference variables"; + } + // Mark main function public, others private. for (auto function : module.get().getOps()) { auto visibility = function.getName() == func_name diff --git a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc index 4792e220b17..a51a3697b1c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc @@ -21,11 +21,6 @@ Status GenerateResourceSharedNameIfEmpty(Graph& graph, FunctionLibraryDefinition& flib_def) { auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def, const OpDef& op_def) { - // Only upgrade when it is a resource handle op. - if (op_def.output_arg().size() != 1 || - op_def.output_arg(0).type() != tensorflow::DT_RESOURCE) - return false; - // If the OpDef has "use_node_name_sharing" field, then it is valid to use // node names as shared names. if (!std::any_of(op_def.attr().begin(), op_def.attr().end(), diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index b55a5aa5243..13804e324ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -277,10 +277,6 @@ void CreateConvertMlirToXlaHloPipeline( pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); - - // TODO(b/159127949): Stack and TensorArray decomposition passes do not handle - // region based control flow yet. So convert back to functional control flow. - pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); pm.addPass(mlir::TF::CreateStackOpsDecompositionPass()); pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass()); pm.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); @@ -289,6 +285,10 @@ void CreateConvertMlirToXlaHloPipeline( // Guarantee all functions have one use, which enables shape inference. pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + // TODO(b/171426148): We cannot completely remove region to functional control + // flow conversion from this pipeline yet as it causes some unit tests to + // fail. + pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); // LegalizeTFControlFlow encapsulates arguments for control flow operations // with a tuple argument which break the assumption of resource lifting // inside PromoteResourcesToArgs. diff --git a/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td b/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td index c5a059e5b6b..c17939fd962 100644 --- a/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td +++ b/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td @@ -23,10 +23,6 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" // Checks if the value has only one user. def HasOneUse : Constraint>; -// Constraint that makes sure both operands are the same operands. -// TODO(b/154826385): Reconsider once equal source pattern symbols are allowed. -def EqualOperands : Constraint>; - // Checks if the operand0's rank is one less than operand1's rank. def PReluAlphaRankCheck : Constraint< CPred<"$0.getType().cast().getRank() == " @@ -36,13 +32,12 @@ def PReluAlphaRankCheck : Constraint< // PReLU pattern from Keras: // f(x) = Relu(x) + (-alpha * Relu(-x)) def : Pat<(TF_AddV2Op - (TF_ReluOp:$relu_out $input1), + (TF_ReluOp:$relu_out $x), (TF_MulOp:$mul_out - (TF_ReluOp (TF_NegOp:$input_neg_out $input2)), + (TF_ReluOp (TF_NegOp:$input_neg_out $x)), $neg_alpha)), - (TFJS_PReluOp $input1, (TF_NegOp $neg_alpha)), - [(EqualOperands $input1, $input2), - (PReluAlphaRankCheck $neg_alpha, $input1), + (TFJS_PReluOp $x, (TF_NegOp $neg_alpha)), + [(PReluAlphaRankCheck $neg_alpha, $x), (HasOneUse $relu_out), (HasOneUse $mul_out), (HasOneUse $input_neg_out) diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index fbdc2d7d93b..ae02faf8be2 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -1,10 +1,17 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_py_test") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") +load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") load( "//third_party/mlir:tblgen.bzl", "gentbl", ) load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( default_visibility = [ @@ -37,10 +44,12 @@ filegroup( "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], + compatible_with = get_compatible_with_cloud(), ) gentbl( name = "tfr_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), tbl_outs = [ ( "-gen-op-decls", @@ -114,7 +123,9 @@ cc_library( ":tfr", ":utils", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -219,6 +230,7 @@ cc_library( deps = [ ":tfr_decompose_ctx", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", + "//tensorflow/core:lib", "//tensorflow/stream_executor/lib", "@llvm-project//mlir:IR", ], @@ -232,9 +244,9 @@ tf_py_test( data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"], python_version = "PY3", tags = [ + "no_pip", "no_windows", # TODO(b/170752141) "nomac", # TODO(b/170752141) - "notsan", # TODO(b/170752141) ], deps = [ "//tensorflow/compiler/mlir/tfr/resources:composite_ops", @@ -248,6 +260,7 @@ cc_library( hdrs = ["integration/node_expansion_pass.h"], deps = [ ":tfr_decompose_ctx", + "//tensorflow/core:lib", "//tensorflow/core/common_runtime/eager:core", "//tensorflow/core/common_runtime/eager:eager_op_rewrite_registry", "//tensorflow/stream_executor/lib", @@ -263,6 +276,7 @@ tf_py_test( data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"], python_version = "PY3", tags = [ + "no_pip", "no_windows", # TODO(b/170752141) "nomac", # TODO(b/170752141) ], @@ -304,10 +318,6 @@ py_library( deps = [ "//tensorflow:tensorflow_py", "//tensorflow/compiler/mlir/tfr:tfr_wrapper", - "//tensorflow/python:op_def_registry", - "//tensorflow/python/autograph/pyct", - "//tensorflow/python/autograph/pyct/static_analysis", - "@gast_archive//:gast", ], ) @@ -323,7 +333,6 @@ tf_py_test( "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", "//tensorflow/compiler/mlir/tfr/resources:test_ops", "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", "//tensorflow/python:math_ops", ], ) @@ -334,9 +343,6 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", - "//tensorflow/python:op_def_registry", - "//tensorflow/python/autograph/pyct", - "@gast_archive//:gast", ], ) @@ -345,11 +351,24 @@ py_test( size = "small", srcs = ["python/op_reg_gen_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ ":composite", ":op_reg_gen", "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", - "//tensorflow/python:client_testlib", ], ) + +py_library( + name = "test_utils", + srcs = ["python/test_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +gen_op_libraries( + name = "one_op", + src = "define_op_template.py", +) diff --git a/tensorflow/compiler/mlir/tfr/build_defs.bzl b/tensorflow/compiler/mlir/tfr/build_defs.bzl new file mode 100644 index 00000000000..c00b5c88eee --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/build_defs.bzl @@ -0,0 +1,151 @@ +"""BUILD extension for TF composition project.""" + +load("//tensorflow:tensorflow.bzl", "py_binary", "tf_custom_op_library", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") + +def gen_op_libraries( + name, + src, + deps = [], + tags = [], + test = False): + """gen_op_libraries() generates all cc and py libraries for composite op source. + + Args: + name: used as the name component of all the generated libraries. + src: File contains the composite ops. + deps: Libraries the 'src' depends on. + tags: + test: + """ + if not src.endswith(".py") or name == src[:-3]: + fail("'src' %s conflicts with op Python wrapper. Rename it to be different from 'name'." % src) + + gen_op_lib_exec = src[:-3] # Strip off the .py + py_binary( + name = gen_op_lib_exec, + srcs = [src], + srcs_version = "PY2AND3", + python_version = "PY3", + deps = [ + "//tensorflow/compiler/mlir/tfr:op_reg_gen", + "//tensorflow/compiler/mlir/tfr:tfr_gen", + "//tensorflow/compiler/mlir/tfr:composite", + ] + deps, + ) + + registed_op = "registed_" + name + native.genrule( + name = registed_op, + srcs = [], + outs = [name + ".inc.cc"], + cmd = "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec, + exec_tools = [":" + gen_op_lib_exec], + tags = tags, + ) + + native.cc_library( + name = name + "_cc", + testonly = test, + srcs = [":" + registed_op], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, + ) + + tf_custom_op_library( + name = name + ".so", + srcs = [":" + registed_op], + ) + + tf_gen_op_wrapper_py( + name = "gen_" + name, + out = "gen_" + name + ".py", + deps = [ + ":%s_cc" % name, + ], + ) + + tf_custom_op_py_library( + name = name, + dso = [":%s.so" % name], + kernels = [":%s_cc" % name], + srcs_version = "PY2AND3", + deps = [ + ":gen_%s" % name, + ], + ) + + # Link the register op and rebuild the binary + gen_tfr_lib_exec = gen_op_lib_exec + "_with_op_library" + py_binary( + name = gen_tfr_lib_exec, + main = src, + srcs = [src], + srcs_version = "PY2AND3", + python_version = "PY3", + deps = [ + "//tensorflow/compiler/mlir/tfr:op_reg_gen", + "//tensorflow/compiler/mlir/tfr:tfr_gen", + "//tensorflow/compiler/mlir/tfr:composite", + ":%s" % name, + ] + deps, + ) + + native.genrule( + name = name + "_mlir", + srcs = [], + outs = [name + ".mlir"], + cmd = "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec, + exec_tools = [":" + gen_tfr_lib_exec], + tags = tags, + ) + + native.py_library( + name = name + "_py", + srcs = [src], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/compiler/mlir/tfr:op_reg_gen", + "//tensorflow/compiler/mlir/tfr:tfr_gen", + "//tensorflow/compiler/mlir/tfr:composite", + ] + deps, + ) + +def gen_op_bindings(name): + native.cc_library( + name = name + "_ops_cc", + srcs = [name + "_ops.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, + ) + + tf_custom_op_library( + name = name + "_ops.so", + srcs = [name + "_ops.cc"], + ) + + tf_gen_op_wrapper_py( + name = "gen_" + name + "_ops", + out = "gen_" + name + "_ops.py", + deps = [ + ":" + name + "_ops_cc", + ], + ) + + tf_custom_op_py_library( + name = name + "_ops", + dso = [":" + name + "_ops.so"], + kernels = [":" + name + "_ops_cc"], + visibility = ["//visibility:public"], + deps = [ + ":gen_" + name + "_ops", + ], + ) diff --git a/tensorflow/compiler/mlir/tfr/define_op_template.py b/tensorflow/compiler/mlir/tfr/define_op_template.py new file mode 100644 index 00000000000..c0db2981d2d --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/define_op_template.py @@ -0,0 +1,64 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A template to define composite ops.""" + +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from tensorflow.compiler.mlir.tfr.python.composite import Composite +from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op +from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module +from tensorflow.python.platform import app +from tensorflow.python.platform import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + 'output', None, + 'Path to write the genereated register op file and MLIR file.') + +flags.DEFINE_bool('gen_register_op', True, + 'Generate register op cc file or tfr mlir file.') + +flags.mark_flag_as_required('output') + + +@Composite('TestRandom', derived_attrs=['T: numbertype'], outputs=['o: T']) +def _composite_random_op(): + pass + + +def main(_): + if FLAGS.gen_register_op: + assert FLAGS.output.endswith('.cc') + generated_code = gen_register_op(sys.modules[__name__], '_composite_') + else: + assert FLAGS.output.endswith('.mlir') + generated_code = tfr_gen_from_module(sys.modules[__name__], '_composite_') + + dirname = os.path.dirname(FLAGS.output) + if not os.path.exists(dirname): + os.makedirs(dirname) + with open(FLAGS.output, 'w') as f: + f.write(generated_code) + + +if __name__ == '__main__': + app.run(main=main) diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD new file mode 100644 index 00000000000..eeaee926c87 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD @@ -0,0 +1,60 @@ +load("//tensorflow:tensorflow.bzl", "py_binary") +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") + +package( + default_visibility = [ + ":friends", + ], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = ["//third_party/mlir:subpackages"], + packages = [ + "//tensorflow/compiler/mlir/tfr/...", + ], +) + +gen_op_libraries( + name = "mnist_ops", + src = "ops_defs.py", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +tf_py_test( + name = "mnist_ops_test", + size = "small", + srcs = ["mnist_ops_test.py"], + data = [":mnist_ops_mlir"], + python_version = "PY3", + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "no_windows", # TODO(b/170752141) + "nomac", # TODO(b/170752141) + ], + deps = [ + ":mnist_ops", + ":mnist_ops_py", + "//tensorflow:tensorflow_py", + "//tensorflow/compiler/mlir/tfr:test_utils", + ], +) + +py_binary( + name = "mnist_train", + srcs = ["mnist_train.py"], + data = [":mnist_ops_mlir"], + python_version = "PY3", + deps = [ + ":mnist_ops", + ":mnist_ops_py", + "//tensorflow:tensorflow_py", + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +) diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_ops_test.py b/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_ops_test.py new file mode 100644 index 00000000000..d25b424279f --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_ops_test.py @@ -0,0 +1,126 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for tensorflow.compiler.mlir.tfr.examples.mnist.ops_defs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tensorflow as tf + +from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops +from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs +from tensorflow.compiler.mlir.tfr.python import test_utils +from tensorflow.python.framework import load_library +from tensorflow.python.platform import test + +_lib_dir = os.path.dirname(gen_mnist_ops.__file__) +_lib_name = os.path.basename(gen_mnist_ops.__file__)[4:].replace('.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + + +class MnistOpsDefsTest(test_utils.OpsDefsTest): + + def test_new_conv2d_relu(self): + input_ = tf.random.uniform([1, 4, 4, 1]) + filter_ = tf.random.uniform([2, 2, 1, 8]) + bias = tf.zeros([8]) + kwargs = { + 'input_': input_, + 'filter_': filter_, + 'bias': bias, + 'stride_w': 2, + 'stride_h': 2, + 'dilation_w': 1, + 'dilation_h': 1, + 'padding': 'SAME', + 'act': 'RELU' + } + + self._assertOpAndComposite([input_, filter_, bias], + tf.function(gen_mnist_ops.new_conv2d), + ops_defs._composite_conv_add_relu, kwargs) + + def test_new_conv2d_relu6(self): + input_ = tf.random.uniform([1, 4, 4, 1]) + filter_ = tf.random.uniform([2, 2, 1, 8]) + bias = tf.zeros([8]) + kwargs = { + 'input_': input_, + 'filter_': filter_, + 'bias': bias, + 'stride_w': 2, + 'stride_h': 2, + 'dilation_w': 1, + 'dilation_h': 1, + 'padding': 'SAME', + 'act': 'RELU6' + } + + self._assertOpAndComposite([input_, filter_, bias], + tf.function(gen_mnist_ops.new_conv2d), + ops_defs._composite_conv_add_relu, kwargs) + + def test_new_conv2d_tanh(self): + self.skipTest('Fix tanh gradients') + input_ = tf.random.uniform([1, 4, 4, 1]) + filter_ = tf.random.uniform([2, 2, 1, 8]) + bias = tf.zeros([8]) + kwargs = { + 'input_': input_, + 'filter_': filter_, + 'bias': bias, + 'stride_w': 2, + 'stride_h': 2, + 'dilation_w': 1, + 'dilation_h': 1, + 'padding': 'SAME', + 'act': 'TANH' + } + + self._assertOpAndComposite([input_, filter_, bias], + tf.function(gen_mnist_ops.new_conv2d), + ops_defs._composite_conv_add_relu, kwargs) + + def test_new_fully_connected(self): + input_ = tf.random.uniform([2, 4]) + filter_ = tf.random.uniform([3, 4]) + bias = tf.zeros([3]) + kwargs = {'input_': input_, 'filter_': filter_, 'bias': bias, 'act': 'RELU'} + + self._assertOpAndComposite([input_, filter_, bias], + tf.function(gen_mnist_ops.new_fully_connected), + ops_defs._composite_fully_connected, kwargs) + + def test_new_max_pool(self): + input_ = tf.random.uniform([8, 4, 4, 1]) + kwargs = { + 'input_': input_, + 'stride_w': 2, + 'stride_h': 2, + 'filter_width': 1, + 'filter_height': 1, + 'padding': 'SAME', + } + + self._assertOpAndComposite([input_], + tf.function(gen_mnist_ops.new_max_pool), + ops_defs._composite_max_pool, kwargs) + + +if __name__ == '__main__': + os.environ[ + 'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/mnist' + test.main() diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_train.py b/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_train.py new file mode 100644 index 00000000000..a4adcf86d5b --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/mnist_train.py @@ -0,0 +1,179 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MNIST model float training script with TensorFlow graph execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from absl import app +from absl import flags + +import tensorflow as tf +import tensorflow_datasets as tfds +from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops +from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs # pylint: disable=unused-import +from tensorflow.python.framework import load_library + +flags.DEFINE_integer('train_steps', 200, 'Number of steps in training.') + +_lib_dir = os.path.dirname(gen_mnist_ops.__file__) +_lib_name = os.path.basename(gen_mnist_ops.__file__)[4:].replace('.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + +# MNIST dataset parameters. +num_classes = 10 # total classes (0-9 digits). +num_features = 784 # data features (img shape: 28*28). +num_channels = 1 + +# Training parameters. +learning_rate = 0.01 +display_step = 10 +batch_size = 128 + +# Network parameters. +n_hidden_1 = 32 # 1st conv layer number of neurons. +n_hidden_2 = 64 # 2nd conv layer number of neurons. +n_hidden_3 = 1024 # 1st fully connected layer of neurons. +flatten_size = num_features // 16 * n_hidden_2 + +seed = 66478 + +weights = { + 'f1': + tf.Variable( + tf.random.truncated_normal([5, 5, num_channels, n_hidden_1], + stddev=0.1, + seed=seed)), + 'f2': + tf.Variable( + tf.random.truncated_normal([5, 5, n_hidden_1, n_hidden_2], + stddev=0.1, + seed=seed)), + 'f3': + tf.Variable( + tf.random.truncated_normal([n_hidden_3, flatten_size], + stddev=0.1, + seed=seed)), + 'f4': + tf.Variable( + tf.random.truncated_normal([num_classes, n_hidden_3], + stddev=0.1, + seed=seed)), +} + +biases = { + 'b1': tf.Variable(tf.zeros([n_hidden_1])), + 'b2': tf.Variable(tf.zeros([n_hidden_2])), + 'b3': tf.Variable(tf.zeros([n_hidden_3])), + 'b4': tf.Variable(tf.zeros([num_classes])), +} + + +class FloatModel(tf.Module): + """Float inference for mnist model.""" + + @tf.function + def __call__(self, data): + """The Model definition.""" + x = tf.reshape(data, [-1, 28, 28, 1]) + + # 2D convolution, with 'SAME' padding (i.e. the output feature map has + # the same size as the input). + + # NOTE: The data/x/input is always specified in floating point precision. + # output shape: [-1, 28, 28, 32] + conv1 = gen_mnist_ops.new_conv2d(x, weights['f1'], biases['b1'], 1, 1, 1, 1, + 'SAME', 'RELU') + + # Max pooling. The kernel size spec {ksize} also follows the layout of + # the data. Here we have a pooling window of 2, and a stride of 2. + # output shape: [-1, 14, 14, 32] + max_pool1 = gen_mnist_ops.new_max_pool(conv1, 2, 2, 2, 2, 'SAME') + + # output shape: [-1, 14, 14, 64] + conv2 = gen_mnist_ops.new_conv2d(max_pool1, weights['f2'], biases['b2'], 1, + 1, 1, 1, 'SAME', 'RELU') + + # output shape: [-1, 7, 7, 64] + max_pool2 = gen_mnist_ops.new_max_pool(conv2, 2, 2, 2, 2, 'SAME') + + # Reshape the feature map cuboid into a 2D matrix to feed it to the + # fully connected layers. + # output shape: [-1, 7*7*64] + reshape = tf.reshape(max_pool2, [-1, flatten_size]) + + # output shape: [-1, 1024] + fc1 = gen_mnist_ops.new_fully_connected(reshape, weights['f3'], + biases['b3'], 'RELU') + # output shape: [-1, 10] + return gen_mnist_ops.new_fully_connected(fc1, weights['f4'], biases['b4']) + + +def grad(model, inputs, labels, trainable_variables): + with tf.GradientTape() as tape: + logits = model(inputs) + loss_value = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels, logits)) + grads = tape.gradient(loss_value, trainable_variables) + correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + return accuracy, loss_value, grads + + +def training_step(model, inputs, labels, optimizer, step): + trainable_variables = list(weights.values()) + list(biases.values()) + accuracy, loss_value, grads = grad(model, inputs, labels, trainable_variables) + if step % display_step == 0: + print('Step %d:' % step) + print(' Loss = %f' % loss_value) + print(' Batch accuracy: %f' % accuracy) + optimizer.apply_gradients(zip(grads, trainable_variables)) + + +def get_next_batch(iter_): + features = next(iter_) + images, labels = features['image'], features['label'] + return (mnist_preprocess(images), tf.one_hot(labels, num_classes)) + + +def mnist_preprocess(x): + x_float = tf.cast(x, tf.float32) + return x_float / 255.0 + + +def train(model, dataset, optimizer): + iter_ = iter(dataset) + for step in range(flags.FLAGS.train_steps): + inputs, labels = get_next_batch(iter_) + training_step(model, inputs, labels, optimizer, step) + + +def main(_): + # TODO(fengliuai): put this in some automatically generated code. + os.environ[ + 'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/mnist' + # Create an mnist float model with the specified float state. + model = FloatModel() + optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) + + ds_train = tfds.load('mnist', split='train', shuffle_files=True) + ds_train = ds_train.shuffle(1024).batch(batch_size).prefetch(64) + + train(model, ds_train, optimizer) + + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/ops_defs.py b/tensorflow/compiler/mlir/tfr/examples/mnist/ops_defs.py new file mode 100644 index 00000000000..0cf4678892e --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/ops_defs.py @@ -0,0 +1,217 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines all the new composite ops used in the mnist example.""" + +# pylint: disable=g-direct-tensorflow-import +# pylint: disable=missing-function-docstring + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +import tensorflow as tf + +from tensorflow.compiler.mlir.tfr.python import composite +from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op +from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import app +from tensorflow.python.platform import flags + +Composite = composite.Composite +FLAGS = flags.FLAGS + +flags.DEFINE_string( + 'output', None, + 'Path to write the genereated register op file and MLIR file.') + +flags.DEFINE_bool('gen_register_op', True, + 'Generate register op cc file or tfr mlir file.') + + +@Composite( + 'NewConv2D', + inputs=['input_: T', 'filter_: T', 'bias: T'], + attrs=[ + 'stride_w: int', 'stride_h: int', 'dilation_w: int', 'dilation_h: int', + 'padding: {"SAME", "VALID"}', 'act: {"", "RELU", "RELU6", "TANH"} = ""' + ], + derived_attrs=['T: {float, int8}'], + outputs=['o: T']) +def _composite_conv_add_relu(input_, filter_, bias, stride_w, stride_h, + dilation_w, dilation_h, padding, act): + res = tf.raw_ops.Conv2D( + input=input_, + filter=filter_, + strides=[1, stride_w, stride_h, 1], + dilations=[1, dilation_w, dilation_h, 1], + padding=padding) + res = tf.raw_ops.Add(x=res, y=bias) + if act == 'RELU': + return tf.raw_ops.Relu(features=res) + elif act == 'RELU6': + return tf.raw_ops.Relu6(features=res) + elif act == 'TANH': + return tf.raw_ops.Tanh(x=res) + else: + return res + + +@tf.RegisterGradient('NewConv2D') +def _conv_add_relu_grad(op, grad): + act = op.get_attr('act') + y = op.outputs[0] + if act == 'RELU': + grad = gen_nn_ops.relu_grad(grad, y) + elif act == 'RELU6': + grad = gen_nn_ops.relu6_grad(grad, y) + elif act == 'TANH': + y = math_ops.conj(y) + grad = gen_math_ops.tanh_grad(y, grad) + + broadcast_shape = tf.shape(y) + input_value_shape = tf.shape(op.inputs[2]) + _, reduction_axes = tf.raw_ops.BroadcastGradientArgs( + s0=broadcast_shape, s1=input_value_shape) + updates_grad_reshaped = tf.reduce_sum( + grad, axis=reduction_axes, keepdims=True) + bias_grad = tf.reshape(updates_grad_reshaped, input_value_shape) + + dilations = [1, op.get_attr('dilation_w'), op.get_attr('dilation_h'), 1] + strides = [1, op.get_attr('stride_w'), op.get_attr('stride_h'), 1] + padding = op.get_attr('padding') + shape_0, shape_1 = tf.shape_n([op.inputs[0], op.inputs[1]]) + return [ + tf.compat.v1.nn.conv2d_backprop_input( + shape_0, + op.inputs[1], + grad, + strides=strides, + padding=padding, + dilations=dilations, + data_format='NHWC'), + tf.compat.v1.nn.conv2d_backprop_filter( + op.inputs[0], + shape_1, + grad, + strides=strides, + padding=padding, + dilations=dilations, + data_format='NHWC'), bias_grad + ] + + +@Composite( + 'NewFullyConnected', + inputs=['input_: T', 'filter_: T', 'bias: T'], + attrs=['act: {"", "RELU", "RELU6", "TANH"} = ""'], + derived_attrs=['T: {float, int8}'], + outputs=['o: T']) +def _composite_fully_connected(input_, filter_, bias, act): + res = tf.raw_ops.MatMul( + a=input_, b=filter_, transpose_a=False, transpose_b=True) + res = tf.raw_ops.Add(x=res, y=bias) + if act == 'RELU': + return tf.raw_ops.Relu(features=res) + elif act == 'RELU6': + return tf.raw_ops.Relu6(features=res) + elif act == 'TANH': + return tf.raw_ops.Tanh(x=res) + else: + return res + + +@tf.RegisterGradient('NewFullyConnected') +def _fully_connected_grad(op, grad): + act = op.get_attr('act') + y = op.outputs[0] + if act == 'RELU': + grad = gen_nn_ops.relu_grad(grad, y) + elif act == 'RELU6': + grad = gen_nn_ops.relu6_grad(grad, y) + elif act == 'TANH': + y = math_ops.conj(y) + grad = gen_math_ops.tanh_grad(y, grad) + + broadcast_shape = tf.shape(y) + input_value_shape = tf.shape(op.inputs[2]) + _, reduction_axes = tf.raw_ops.BroadcastGradientArgs( + s0=broadcast_shape, s1=input_value_shape) + updates_grad_reshaped = tf.reduce_sum( + grad, axis=reduction_axes, keepdims=True) + bias_grad = tf.reshape(updates_grad_reshaped, input_value_shape) + + a = math_ops.conj(op.inputs[0]) + b = math_ops.conj(op.inputs[1]) + grad_a = gen_math_ops.mat_mul(grad, b) + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) + return [grad_a, grad_b, bias_grad] + + +@Composite( + 'NewMaxPool', + inputs=['input_: T'], + attrs=[ + 'stride_w: int', 'stride_h: int', 'filter_width: int', + 'filter_height: int', 'padding: {"SAME", "VALID"}' + ], + derived_attrs=['T: {float, int8}'], + outputs=['o: T']) +def _composite_max_pool(input_, stride_w, stride_h, filter_width, filter_height, + padding): + ksize = [1, filter_width, filter_height, 1] + strides = [1, stride_w, stride_h, 1] + return tf.raw_ops.MaxPool( + input=input_, ksize=ksize, strides=strides, padding=padding) + + +@tf.RegisterGradient('NewMaxPool') +def _max_pool_grad(op, grad): + filter_width = op.get_attr('filter_width') + filter_height = op.get_attr('filter_height') + stride_w = op.get_attr('stride_w') + stride_h = op.get_attr('stride_h') + padding = op.get_attr('padding') + return tf.raw_ops.MaxPoolGrad( + orig_input=op.inputs[0], + orig_output=op.outputs[0], + grad=grad, + ksize=[1, filter_width, filter_height, 1], + strides=[1, stride_w, stride_h, 1], + padding=padding, + data_format='NHWC') + + +def main(_): + if FLAGS.gen_register_op: + assert FLAGS.output.endswith('.cc') + generated_code = gen_register_op(sys.modules[__name__], '_composite_') + else: + assert FLAGS.output.endswith('.mlir') + generated_code = tfr_gen_from_module(sys.modules[__name__], '_composite_',) + + dirname = os.path.dirname(FLAGS.output) + if not os.path.exists(dirname): + os.makedirs(dirname) + with open(FLAGS.output, 'w') as f: + f.write(generated_code) + + +if __name__ == '__main__': + app.run(main=main) diff --git a/tensorflow/compiler/mlir/tfr/examples/pad/BUILD b/tensorflow/compiler/mlir/tfr/examples/pad/BUILD new file mode 100644 index 00000000000..ef08caff939 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/pad/BUILD @@ -0,0 +1,45 @@ +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") + +package( + default_visibility = [ + ":friends", + ], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = ["//third_party/mlir:subpackages"], + packages = [ + "//tensorflow/compiler/mlir/tfr/...", + ], +) + +gen_op_libraries( + name = "pad_ops", + src = "ops_defs.py", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +tf_py_test( + name = "pad_ops_test", + size = "small", + srcs = ["pad_ops_test.py"], + data = [":pad_ops_mlir"], + python_version = "PY3", + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "no_windows", # TODO(b/170752141) + "nomac", # TODO(b/170752141) + ], + deps = [ + ":pad_ops", + ":pad_ops_py", + "//tensorflow:tensorflow_py", + "//tensorflow/compiler/mlir/tfr:test_utils", + ], +) diff --git a/tensorflow/compiler/mlir/tfr/examples/pad/ops_defs.py b/tensorflow/compiler/mlir/tfr/examples/pad/ops_defs.py new file mode 100644 index 00000000000..4b072a58f08 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/pad/ops_defs.py @@ -0,0 +1,168 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the mirror pad and mirror pad grad.""" + +# pylint: disable=g-direct-tensorflow-import +# pylint: disable=missing-function-docstring + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +import tensorflow as tf + +from tensorflow.compiler.mlir.tfr.python import composite +from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op +from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.platform import app +from tensorflow.python.platform import flags + +Composite = composite.Composite +FLAGS = flags.FLAGS + +flags.DEFINE_string( + 'output', None, + 'Path to write the genereated register op file and MLIR file.') + +flags.DEFINE_bool('gen_register_op', True, + 'Generate register op cc file or tfr mlir file.') + + +@Composite( + 'NewMirrorPad', + inputs=['input_: T', 'paddings: Tpaddings'], + attrs=['mode: {"REFLECT", "SYMMETRIC"}'], + derived_attrs=['T: type', 'Tpaddings: {int32, int64} = DT_INT32'], + outputs=['output: T']) +def _composite_mirror_pad(input_, paddings, mode): + shape = input_.shape.as_list() + for i in range(len(shape)): + rdims = tf.raw_ops.OneHot( + indices=i, depth=len(shape), on_value=True, off_value=False, axis=-1) + rarray = tf.raw_ops.Reverse(tensor=input_, dims=rdims) + + left_padding_size = tf.raw_ops.GatherNd(params=paddings, indices=[i, 0]) + right_padding_size = tf.raw_ops.GatherNd(params=paddings, indices=[i, 1]) + + if mode == 'REFLECT': + left_padding, _ = tf.raw_ops.SplitV( + value=rarray, + size_splits=[left_padding_size, -1], + axis=i, + num_split=2) + _, right_padding = tf.raw_ops.SplitV( + value=rarray, + size_splits=[-1, right_padding_size], + axis=i, + num_split=2) + else: + _, left_padding = tf.raw_ops.SplitV( + value=rarray, + size_splits=[-1, left_padding_size], + axis=i, + num_split=2) + right_padding, _ = tf.raw_ops.SplitV( + value=rarray, + size_splits=[right_padding_size, -1], + axis=i, + num_split=2) + + input_ = tf.raw_ops.Concat( + concat_dim=i, values=[left_padding, input_, right_padding]) + return input_ + + +@tf.RegisterGradient('NewMirrorPad') +def _mirror_pad_grad(op, grad): + mode = op.get_attr('mode') + return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None] + + +@Composite( + 'NewMirrorPadGrad', + inputs=['input_: T', 'paddings: Tpaddings'], + attrs=['mode: {"REFLECT", "SYMMETRIC"}'], + derived_attrs=['T: type', 'Tpaddings: {int32, int64} = DT_INT32'], + outputs=['output: T']) +def _composite_mirror_pad_grad(input_, paddings, mode): + shape = input_.shape.as_list() + for i in range(len(shape)): + rdims = tf.raw_ops.OneHot( + indices=i, depth=len(shape), on_value=True, off_value=False, axis=-1) + left_padding_size = tf.raw_ops.GatherNd(params=paddings, indices=[i, 0]) + right_padding_size = tf.raw_ops.GatherNd(params=paddings, indices=[i, 1]) + + left_padding, core, right_padding = tf.raw_ops.SplitV( + value=input_, + size_splits=[left_padding_size, -1, right_padding_size], + axis=i, + num_split=3) + reversed_left_padding = tf.raw_ops.Reverse(tensor=left_padding, dims=rdims) + reversed_right_padding = tf.raw_ops.Reverse( + tensor=right_padding, dims=rdims) + zero_like = tf.raw_ops.ZerosLike(x=core) + left_offset, _ = tf.raw_ops.SplitV( + value=zero_like, + size_splits=[-1, left_padding_size], + axis=i, + num_split=2) + right_offset, _ = tf.raw_ops.SplitV( + value=zero_like, + size_splits=[-1, right_padding_size], + axis=i, + num_split=2) + + if mode == 'REFLECT': + from_left_padding = tf.raw_ops.Concat( + concat_dim=i, values=[left_offset, reversed_left_padding]) + from_right_padding = tf.raw_ops.Concat( + concat_dim=i, values=[reversed_right_padding, right_offset]) + else: + from_left_padding = tf.raw_ops.Concat( + concat_dim=i, values=[reversed_left_padding, left_offset]) + from_right_padding = tf.raw_ops.Concat( + concat_dim=i, values=[right_offset, reversed_right_padding]) + input_ = tf.raw_ops.AddN( + inputs=[from_left_padding, core, from_right_padding]) + + return input_ + + +@tf.RegisterGradient('NewMirrorPadGrad') +def _mirror_pad_grad_grad(op, grad): + mode = op.get_attr('mode') + return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None] + + +def main(_): + if FLAGS.gen_register_op: + assert FLAGS.output.endswith('.cc') + generated_code = gen_register_op(sys.modules[__name__], '_composite_') + else: + assert FLAGS.output.endswith('.mlir') + generated_code = tfr_gen_from_module(sys.modules[__name__], '_composite_') + + dirname = os.path.dirname(FLAGS.output) + if not os.path.exists(dirname): + os.makedirs(dirname) + with open(FLAGS.output, 'w') as f: + f.write(generated_code) + + +if __name__ == '__main__': + app.run(main=main) diff --git a/tensorflow/compiler/mlir/tfr/examples/pad/pad_ops_test.py b/tensorflow/compiler/mlir/tfr/examples/pad/pad_ops_test.py new file mode 100644 index 00000000000..11f6e0acbf2 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/examples/pad/pad_ops_test.py @@ -0,0 +1,96 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for tensorflow.compiler.mlir.tfr.examples.pad.ops_defs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from absl.testing import parameterized +import tensorflow as tf + +from tensorflow.compiler.mlir.tfr.examples.pad import gen_pad_ops +from tensorflow.compiler.mlir.tfr.examples.pad import ops_defs +from tensorflow.compiler.mlir.tfr.python import test_utils +from tensorflow.python.framework import load_library +from tensorflow.python.platform import test + +_lib_dir = os.path.dirname(gen_pad_ops.__file__) +_lib_name = os.path.basename(gen_pad_ops.__file__)[4:].replace('.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + + +class PadOpsDefsTest(test_utils.OpsDefsTest, parameterized.TestCase): + + @parameterized.named_parameters(('ReflectMode', 'REFLECT'), + ('SymmetricMode', 'SYMMETRIC')) + def test_mirror_pad(self, mode): + input_ = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32) + paddings = tf.constant([[ + 1, + 1, + ], [2, 2]]) + kwargs = { + 'input': input_, + 'paddings': paddings, + 'mode': mode, + } + kwargs_ = { + 'input_': input_, + 'paddings': paddings, + 'mode': mode, + } + # Make sure the composition python function is correct + self._assertOpAndComposite([input_], tf.raw_ops.MirrorPad, + ops_defs._composite_mirror_pad, kwargs_, kwargs) + # Make sure the translation and decomposition is correct + self._assertOpAndComposite([input_], + tf.function(gen_pad_ops.new_mirror_pad), + ops_defs._composite_mirror_pad, kwargs_) + + @parameterized.named_parameters(('ReflectMode', 'REFLECT'), + ('SymmetricMode', 'SYMMETRIC')) + def test_mirror_pad_grad(self, mode): + input_ = tf.constant([[2, 1, 1, 2, 3, 3, 2], [2, 1, 1, 2, 3, 3, 2], + [5, 4, 4, 5, 6, 6, 5], [5, 4, 4, 5, 6, 6, 5]], + dtype=tf.float32) + paddings = tf.constant([[ + 1, + 1, + ], [2, 2]]) + kwargs = { + 'input': input_, + 'paddings': paddings, + 'mode': mode, + } + kwargs_ = { + 'input_': input_, + 'paddings': paddings, + 'mode': mode, + } + # Make sure the composition python function is correct + self._assertOpAndComposite([input_], tf.raw_ops.MirrorPadGrad, + ops_defs._composite_mirror_pad_grad, kwargs_, + kwargs) + # Make sure the translation and decomposition is correct + self._assertOpAndComposite([input_], + tf.function(gen_pad_ops.new_mirror_pad_grad), + ops_defs._composite_mirror_pad_grad, kwargs_) + + +if __name__ == '__main__': + os.environ[ + 'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/pad' + test.main() diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc index c283b3b6986..7041545637a 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc @@ -14,10 +14,20 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { +namespace { + +auto* tf_core_op_expansion_graph_counter = + monitoring::Counter<0>::New("/tensorflow/core/op_expansion/graph_counter", + "The number of graphs being op expanded."); +} // namespace + +namespace tfr { bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto) const { const char* tfr_lib_env_val = getenv(std::string(kTFRLibEnv).c_str()); @@ -27,14 +37,20 @@ bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto) const { Status GraphDecomposePass::Run(const ConfigProto& config_proto, mlir::ModuleOp module) { if (!IsEnabled(config_proto)) { - VLOG(1) << "Skipping Graph Decomposition Pass, decompositin library was " - "not found"; + LOG_FIRST_N(INFO, 1) << "Skipping Graph Decomposition Pass, decompositin " + "library was not found"; return Status::OK(); } - VLOG(1) << "Run Graph Decomposition Passes"; - TF_ASSIGN_OR_RETURN(ctx_, TFRDecomposeContext::Get(module.getContext())); - TF_RETURN_IF_ERROR(ctx_->Decompose(module)); - return ctx_->Destroy(); + + tf_core_op_expansion_graph_counter->GetCell()->IncrementBy(1); + + LOG_FIRST_N(INFO, 1) << "Run Graph Decomposition Passes"; + + TF_RETURN_IF_ERROR(DecomposeGraph(module)); + + LOG_FIRST_N(INFO, 1) << "Finish Graph Decomposition Passes"; + + return Status::OK(); } namespace { @@ -45,4 +61,5 @@ static mlir_pass_registration::MlirOptimizationPassRegistration std::make_unique()); } // namespace +} // namespace tfr } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h index 589865c0f7d..dd93e99f04b 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { +namespace tfr { // An optimization pass that decompose the composite ops in a module according // to the decomposition library. Currently the decomposition library is loaded @@ -37,11 +38,9 @@ class GraphDecomposePass : public MlirOptimizationPass { // This should be used as a thin mapper around mlir::ModulePass::runOnModule // API integrated with the Tensorflow runtime. Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override; - - private: - std::unique_ptr ctx_; }; +} // namespace tfr } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_ diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py index d11ad9eef81..d573b8e7195 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py @@ -35,14 +35,6 @@ load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) class GraphDecomposeTest(test.TestCase): - def setUp(self): - os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources' - super(GraphDecomposeTest, self).setUp() - - def tearDown(self): - del os.environ['TF_MLIR_TFR_LIB_DIR'] - super(GraphDecomposeTest, self).tearDown() - def testAddN(self): add = def_function.function(gen_composite_ops.my_add_n) t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) @@ -86,5 +78,6 @@ class GraphDecomposeTest(test.TestCase): if __name__ == '__main__': + os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources' ops.enable_eager_execution() test.main() diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc index 81106690590..5bb7d235fa7 100644 --- a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc @@ -18,16 +18,27 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { +namespace { + +auto* tf_core_op_expansion_node_counter = + monitoring::Counter<0>::New("/tensorflow/core/op_expansion/node_counter", + "The number of nodes being op expanded."); +} // namespace + +namespace tfr { Status CompositeOpExpansion::Run(EagerOperation* orig_op, std::unique_ptr* out_op) { if (!IsEnabled()) return Status::OK(); if (orig_op->Device() != kVariantDeviceNull) return Status::OK(); - VLOG(1) << "Run Node Expansion Passes"; + tf_core_op_expansion_node_counter->GetCell()->IncrementBy(1); + + LOG_FIRST_N(INFO, 1) << "Run Node Expansion Passes"; // Get the FunctionDef and insert that into the context const NodeDef& ndef = orig_op->MutableAttrs()->BuildNodeDef(); @@ -40,7 +51,7 @@ Status CompositeOpExpansion::Run(EagerOperation* orig_op, std::string fname = absl::StrCat("_expanded_", ndef.name(), "_", std::to_string(x)); if (!ctx.FindFunctionByName(fname)) { - TF_ASSIGN_OR_RETURN(auto func, TFRDecomposeContext::Expand(ndef, fname)); + TF_ASSIGN_OR_RETURN(auto func, ExpandNode(ndef, fname)); TF_RETURN_IF_ERROR(ctx.AddFunctionDef(func)); } @@ -55,11 +66,14 @@ Status CompositeOpExpansion::Run(EagerOperation* orig_op, new_op->MutableAttrs()->CopyAttributes(orig_op->Attrs()); out_op->reset(new_op); - VLOG(1) << "Rewrite the op to call function: " << fname; + LOG_FIRST_N(INFO, 1) + << "Finish Node Expansion Passes. Rewrite the op to call function: " + << fname; return Status::OK(); } REGISTER_REWRITE(EagerOpRewriteRegistry::POST_PLACEMENT, CompositeOpExpansion); +} // namespace tfr } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h index 7e3aa82026b..b1e4911b541 100644 --- a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { +namespace tfr { // An optimization pass that decompose the composite ops in a module according // to the decomposition library. Currently the decomposition library is loaded @@ -42,6 +43,7 @@ class CompositeOpExpansion : public EagerOpRewrite { } }; +} // namespace tfr } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_NODE_EXPANSION_PASS_H_ diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py b/tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py index 6bdaed303ef..f99b52fe65a 100644 --- a/tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py @@ -34,14 +34,6 @@ load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) class NodeExpansionTest(test.TestCase): - def setUp(self): - os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources' - super(NodeExpansionTest, self).setUp() - - def tearDown(self): - del os.environ['TF_MLIR_TFR_LIB_DIR'] - super(NodeExpansionTest, self).tearDown() - def testAddN(self): t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) @@ -81,5 +73,6 @@ class NodeExpansionTest(test.TestCase): if __name__ == '__main__': + os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources' ops.enable_eager_execution() test.main() diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc index 91fc40941e0..61e96548579 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { +namespace tfr { const char* const kTFRLibEnv = "TF_MLIR_TFR_LIB_DIR"; @@ -66,6 +67,10 @@ StatusOr> TFRDecomposeContext::Get( string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir); std::vector files; TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files)); + if (files.empty()) { + return errors::Internal(absl::StrCat( + "Failed to find the decomposition lib from path ", composite_mlir_dir)); + } std::string tfr_raw_text; for (const auto& file : files) { string fullpath = io::JoinPath(composite_mlir_dir, file); @@ -76,7 +81,7 @@ StatusOr> TFRDecomposeContext::Get( } } - auto ctx = TFRDecomposeContext::Get(tfr_raw_text, mlir_ctx); + auto ctx = TFRDecomposeContext::GetFromText(tfr_raw_text, mlir_ctx); if (!ctx) { return errors::Internal(absl::StrCat( "Failed to load the imported decomposition lib: ", tfr_raw_text)); @@ -84,7 +89,7 @@ StatusOr> TFRDecomposeContext::Get( return ctx; } -std::unique_ptr TFRDecomposeContext::Get( +std::unique_ptr TFRDecomposeContext::GetFromText( StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx) { mlir_ctx->allowUnregisteredDialects(/*allow=*/true); // Load dialects involved in the conversion @@ -105,20 +110,22 @@ std::unique_ptr TFRDecomposeContext::Get( llvm::SourceMgr source_mgr; source_mgr.AddNewSourceBuffer(std::move(memory_buffer), llvm::SMLoc()); mlir::OwningModuleRef module = mlir::parseSourceFile(source_mgr, mlir_ctx); + // The MLIRContext owns the module + auto module_op = module.release(); // Create the context - return absl::make_unique(std::move(module)); + return absl::make_unique(module_op); } -StatusOr TFRDecomposeContext::Decompose(const NodeDef& node_def, - StringPiece func_name) { +StatusOr TFRDecomposeContext::ExpandNode(const NodeDef& node_def, + StringPiece func_name) { const OpDef* op_def; TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def)); DataTypeVector input_dtys, output_dtys; TF_RETURN_IF_ERROR(InputTypesForNode(node_def, *op_def, &input_dtys)); TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, &output_dtys)); - mlir::MLIRContext* context = tfr_module_->getContext(); + mlir::MLIRContext* context = tfr_module_.getContext(); llvm::SmallVector input_tys, output_tys; mlir::Builder builder(context); for (auto ty : input_dtys) { @@ -159,15 +166,8 @@ StatusOr TFRDecomposeContext::Decompose(const NodeDef& node_def, mlir::Operation* tf_op = op_builder.createOperation(op_state); op_builder.create(loc, tf_op->getResults()); - if (failed(mlir::verify(module))) { - return errors::Internal(absl::StrCat( - "Failed to verify the imported NodeDef: ", node_def.DebugString())); - } - - // Call the decompose passes by using the external symbol table. - if (failed(pm_.run(module))) { - return errors::Internal("Failed to run the decompose passes."); - } + // Run the decompose passes on the module + TF_RETURN_IF_ERROR(DecomposeGraph(module)); // Export the result as a FunctionDef. FunctionDef func_def; @@ -177,7 +177,7 @@ StatusOr TFRDecomposeContext::Decompose(const NodeDef& node_def, return func_def; } -Status TFRDecomposeContext::Decompose(mlir::ModuleOp user_module) { +Status TFRDecomposeContext::DecomposeGraph(mlir::ModuleOp user_module) { // Call the decompose passes by using the external symbol table. if (failed(pm_.run(user_module))) { return errors::Internal("Failed to run the decompose passes."); @@ -185,35 +185,38 @@ Status TFRDecomposeContext::Decompose(mlir::ModuleOp user_module) { return Status::OK(); } -StatusOr TFRDecomposeContext::Expand(const NodeDef& node_def, - StringPiece func_name) { - mlir::MLIRContext mlir_ctx; - mlir_ctx.allowUnregisteredDialects(/*allow=*/true); - TF_ASSIGN_OR_RETURN(auto ctx, Get(&mlir_ctx)); - return ctx->Decompose(node_def, func_name); -} - -Status TFRDecomposeContext::Destroy() { - tfr_module_.release().erase(); - return Status::OK(); -} - // Constructor of the decompose context. -TFRDecomposeContext::TFRDecomposeContext(mlir::OwningModuleRef tfr_module) - : tfr_module_(std::move(tfr_module)), pm_(tfr_module_->getContext()) { +TFRDecomposeContext::TFRDecomposeContext(mlir::ModuleOp tfr_module) + : tfr_module_(tfr_module), pm_(tfr_module_.getContext()) { mlir::OpPassManager& func_pm = pm_.nest(); // Prepare the imported graph. func_pm.addPass(mlir::CreateExecutorDialectToFunctionalConversionPass()); // Run TFR lowering, inlining and raising to tf. - func_pm.addPass(mlir::TFR::CreateDecomposeTFOpsPass(tfr_module_.get())); + func_pm.addPass(mlir::TFR::CreateDecomposeTFOpsPass(tfr_module_)); func_pm.addPass(mlir::TFR::CreateRaiseToTFOpsPass( - tfr_module_.get(), /*materialize_derived_attrs=*/true)); + tfr_module_, /*materialize_derived_attrs=*/true)); // Prepare to be exported. func_pm.addPass(mlir::CreateFunctionalToExecutorDialectConversionPass()); pm_.addPass(mlir::CreateBreakUpIslandsPass()); } +void TFRDecomposeContext::Destroy() { tfr_module_.erase(); } + +StatusOr ExpandNode(const NodeDef& node_def, + StringPiece func_name) { + mlir::MLIRContext mlir_ctx; + TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(&mlir_ctx)); + return ctx->ExpandNode(node_def, func_name); +} + +Status DecomposeGraph(mlir::ModuleOp user_module) { + mlir::MLIRContext* mlir_ctx = user_module.getContext(); + TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(mlir_ctx)); + return ctx->DecomposeGraph(user_module); +} + +} // namespace tfr } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h index 548c3ab8c0c..6e33bbf0b0c 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h @@ -23,11 +23,11 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { +namespace tfr { extern const char* const kTFRLibEnv; using stream_executor::port::StatusOr; -using NodeAndType = std::pair; // An wrapper for all the objects used to decompose a module (graph mode) and // node_def (eager mode). Note that this class owns the decomposition library. @@ -38,38 +38,44 @@ class TFRDecomposeContext { static StatusOr> Get( mlir::MLIRContext* mlir_ctx); - static std::unique_ptr Get(StringPiece tfr_raw_text, - mlir::MLIRContext* mlir_ctx); - - // Decompose the NodeDef to a set of primitive ops according to the decompose - // library loaded. Wrap the decomposed result in a FunctionDef. - static StatusOr Expand(const NodeDef& node_def, - StringPiece func_name); - // Constructor of the decompose context. To share the decompose library, the // whole decompose TFR function library is loaded. - explicit TFRDecomposeContext(mlir::OwningModuleRef tfr_module); + explicit TFRDecomposeContext(mlir::ModuleOp tfr_module); - // Decompose the op in the NodeDef to a set of primitive ops according to the + // Constructs the decompose context from the tfr text module and the mlir + // context. The tfr text module is added to the mlir context. + static std::unique_ptr GetFromText( + StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx); + + // Decomposes the op in the NodeDef to a set of primitive ops according to the // decompose library in the context. Wrap the decomposed result in a // FunctionDef. - StatusOr Decompose(const NodeDef& node_def, - StringPiece func_name); + StatusOr ExpandNode(const NodeDef& node_def, + StringPiece func_name); - // Decompose the ops in the ModuleOp to a set of primitive ops according to - // decompose library in the context. - Status Decompose(mlir::ModuleOp user_module); + // Runs the decompose passes on the user_module. + Status DecomposeGraph(mlir::ModuleOp user_module); - // Release all the owned references. - Status Destroy(); + // Erases the tfr_module created. + void Destroy(); private: - mlir::OwningModuleRef tfr_module_; + mlir::ModuleOp tfr_module_; mlir::PassManager pm_; GraphExportConfig export_confs_; }; +// Decomposes the NodeDef to a set of primitive ops according to the decompose +// library loaded. Wrap the decomposed result in a FunctionDef. +StatusOr ExpandNode(const NodeDef& node_def, + StringPiece func_name); + +// Decomposes the ops in the ModuleOp to a set of primitive ops according to +// decompose library in the context. +Status DecomposeGraph(mlir::ModuleOp user_module); + +} // namespace tfr } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc index 614e46df15c..3d83b8d5535 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc @@ -36,6 +36,7 @@ limitations under the License. using testing::ElementsAreArray; using testing::Test; +using NodeAndType = std::pair; namespace tensorflow { @@ -88,11 +89,13 @@ tfr.func @tf__risc_add_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attrib class TFRDecomposeContextTest : public Test { protected: void SetUp() override { - test_ctx_ = TFRDecomposeContext::Get(tfr_raw_text, &ctx_); + test_ctx_ = tfr::TFRDecomposeContext::GetFromText(tfr_raw_text, &ctx_); } + void TearDown() override { test_ctx_->Destroy(); } + mlir::MLIRContext ctx_; - std::unique_ptr test_ctx_; + std::unique_ptr test_ctx_; }; std::vector NodesSequenceOf(const FunctionDef& graph) { @@ -111,7 +114,7 @@ TEST_F(TFRDecomposeContextTest, FLOAT_1_ins) { .Input(src_list) .Finalize(&test_node); EXPECT_TRUE(status.ok()); - auto decomposed = test_ctx_->Decompose(test_node, "test"); + auto decomposed = test_ctx_->ExpandNode(test_node, "test"); EXPECT_TRUE(decomposed.ok()); std::vector expected_results{{"Identity", DT_FLOAT}}; EXPECT_THAT(NodesSequenceOf(decomposed.ValueOrDie()), @@ -128,7 +131,7 @@ TEST_F(TFRDecomposeContextTest, FLOAT_3_ins) { .Input(src_list) .Finalize(&test_node); EXPECT_TRUE(status.ok()); - auto decomposed = test_ctx_->Decompose(test_node, "test"); + auto decomposed = test_ctx_->ExpandNode(test_node, "test"); EXPECT_TRUE(decomposed.ok()); std::vector expected_results{{"RiscAdd", DT_FLOAT}, @@ -146,7 +149,7 @@ TEST_F(TFRDecomposeContextTest, INT32_3_ins) { auto status = NodeDefBuilder("int_add", "MyAddN").Input(src_list).Finalize(&test_node); EXPECT_TRUE(status.ok()); - auto decomposed = test_ctx_->Decompose(test_node, "test"); + auto decomposed = test_ctx_->ExpandNode(test_node, "test"); EXPECT_TRUE(decomposed.ok()); std::vector expected_results{{"RiscAdd", DT_INT32}, diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc index 9265437cca9..e1a9ae8c2e6 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc +++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" @@ -50,6 +51,21 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h" #include "tensorflow/compiler/mlir/tfr/passes/passes.h" #include "tensorflow/compiler/mlir/tfr/utils/utils.h" +#include "tensorflow/core/lib/monitoring/counter.h" + +namespace tensorflow { +namespace { + +auto* tf_core_op_expansion_op_counter = + monitoring::Counter<1>::New("/tensorflow/core/op_expansion/op_counter", + "The number of composite op expanded.", "name"); +} + +void IncreaseOpExpansionExecuteCounterByOne(const std::string& op_name) { + tf_core_op_expansion_op_counter->GetCell(op_name)->IncrementBy(1); +} + +} // namespace tensorflow //===----------------------------------------------------------------------===// // The pass to decompose unregistered TF ops with the TFR compose function. @@ -62,7 +78,6 @@ namespace { // Decompose the TF ops with the registered composition library. struct DecomposeTFOpsPass : public PassWrapper { - explicit DecomposeTFOpsPass(llvm::Optional external_tfr_module) : external_tfr_module(external_tfr_module) {} @@ -118,6 +133,9 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { return; } + tensorflow::IncreaseOpExpansionExecuteCounterByOne( + op->getName().getStringRef().str()); + auto compose_func_type = compose_func.getType(); builder.setInsertionPoint(op); TFRTensorType unconstrainted_tensor_type = builder.getType(); diff --git a/tensorflow/compiler/mlir/tfr/python/op_reg_gen.py b/tensorflow/compiler/mlir/tfr/python/op_reg_gen.py index 4876e17790b..99b2dfdedc4 100644 --- a/tensorflow/compiler/mlir/tfr/python/op_reg_gen.py +++ b/tensorflow/compiler/mlir/tfr/python/op_reg_gen.py @@ -139,8 +139,7 @@ def gen_register_op(source, method_prefix=None): if not method_prefix or name.startswith(method_prefix) ] headers = r""" -#include "third_party/tensorflow/core/framework/node_def_builder.h" -#include "third_party/tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op.h" namespace tensorflow { """ diff --git a/tensorflow/compiler/mlir/tfr/python/op_reg_gen_test.py b/tensorflow/compiler/mlir/tfr/python/op_reg_gen_test.py index 807e48a3dbf..6392015ba4d 100644 --- a/tensorflow/compiler/mlir/tfr/python/op_reg_gen_test.py +++ b/tensorflow/compiler/mlir/tfr/python/op_reg_gen_test.py @@ -55,7 +55,6 @@ class TFRGenTensorTest(test.TestCase): def test_op_reg_gen(self): cxx_code = gen_register_op(sys.modules[__name__]) cxx_code_exp = r""" - CHECK: #include "third_party/tensorflow/core/framework/node_def_builder.h" CHECK-NEXT: #include "third_party/tensorflow/core/framework/op.h" CHECK-EMPTY CHECK-LABEL: namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfr/python/test_utils.py b/tensorflow/compiler/mlir/tfr/python/test_utils.py new file mode 100644 index 00000000000..62aa3e39105 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/test_utils.py @@ -0,0 +1,48 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test utils for composite op definition.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import backprop +from tensorflow.python.platform import test + + +class OpsDefsTest(test.TestCase): + """Test utils.""" + + def _assertOpAndComposite(self, vars_, compute_op, compute_composite, kwargs, + op_kwargs=None): + if op_kwargs is None: + op_kwargs = kwargs + + # compute with op. + with backprop.GradientTape() as gt: + for var_ in vars_: + gt.watch(var_) + y = compute_op(**op_kwargs) # uses op and decomposites by the graph pass. + grads = gt.gradient(y, vars_) # uses registered gradient function. + + # compute with composition + with backprop.GradientTape() as gt: + for var_ in vars_: + gt.watch(var_) + re_y = compute_composite(**kwargs) # uses composite function. + re_grads = gt.gradient(re_y, vars_) # uses gradients compposite function. + + for v, re_v in zip(y, re_y): + self.assertAllClose(v, re_v) + for g, re_g in zip(grads, re_grads): + self.assertAllClose(g, re_g) diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_gen.py b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py index f8622d11511..3bf89c7a2d5 100644 --- a/tensorflow/compiler/mlir/tfr/python/tfr_gen.py +++ b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py @@ -45,6 +45,7 @@ from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs from tensorflow.python.autograph.pyct.static_analysis import type_inference from tensorflow.python.framework import load_library from tensorflow.python.framework import op_def_registry +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect @@ -1339,20 +1340,33 @@ def tfr_gen(func, op_defs): def tfr_gen_from_module(source, method_prefix=None, op_libraries=None): - """Parse a python code and emit the TFR functions from a target class.""" + """Parse the input source module and emit the TFR functions.""" op_defs = OpDefCache() + # Load the op library so the op is added to the op registry. This is + # required when the op cc_library couldn't be statically linked in open + # source. + # This is a no op if the op shared library couldn't be found in the same + # directory of the op Python API. + # TODO(fengliuai): make the .so file path configurable. if op_libraries: + prefix_len = len('gen_') for m in op_libraries: lib_dir = os.path.dirname(m.__file__) - prefix_len = len('gen_') lib_name = os.path.basename(m.__file__)[prefix_len:].replace('.py', '.so') - # Load the op library so the op is added to the op registry. This is - # required when the op cc_library couldn't be statically linked in open - # source. - # This is a no op if the op shared library couldn't be found in the same - # directory of the op Python API. - load_library.load_op_library(os.path.join(lib_dir, lib_name)) + lib_path = os.path.join(lib_dir, lib_name) + if os.path.exists(lib_path): + logging.info('load file: ' + lib_path) + load_library.load_op_library(lib_path) + else: + # The op library is generated from the source module, then we load all the + # .so file in the directory + lib_dir = os.path.dirname(source.__file__) + for lib_name in os.listdir(lib_dir): + if lib_name.endswith('.so'): + lib_path = os.path.join(lib_dir, lib_name) + logging.info('load file: ' + lib_path) + load_library.load_op_library(lib_path) mlir_funcs = [ tfr_gen(func, op_defs) diff --git a/tensorflow/compiler/mlir/tfr/resources/BUILD b/tensorflow/compiler/mlir/tfr/resources/BUILD index 62ca65c5b57..bb3f07d3e7c 100644 --- a/tensorflow/compiler/mlir/tfr/resources/BUILD +++ b/tensorflow/compiler/mlir/tfr/resources/BUILD @@ -1,5 +1,4 @@ -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_bindings") package( default_visibility = [ @@ -22,76 +21,6 @@ filegroup( srcs = ["decomposition_lib.mlir"], ) -cc_library( - name = "composite_ops_cc", - srcs = ["composite_ops.cc"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], - alwayslink = 1, -) +gen_op_bindings(name = "composite") -tf_custom_op_library( - name = "composite_ops.so", - srcs = [ - "composite_ops.cc", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_composite_ops", - out = "gen_composite_ops.py", - deps = [ - ":composite_ops_cc", - ], -) - -tf_custom_op_py_library( - name = "composite_ops", - dso = [":composite_ops.so"], - kernels = [":composite_ops_cc"], - visibility = ["//visibility:public"], - deps = [ - ":gen_composite_ops", - ], -) - -cc_library( - name = "test_ops_cc", - srcs = ["test_ops.cc"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], - alwayslink = 1, -) - -tf_custom_op_library( - name = "test_ops.so", - srcs = [ - "test_ops.cc", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_test_ops", - out = "gen_test_ops.py", - deps = [ - ":test_ops_cc", - ], -) - -tf_custom_op_py_library( - name = "test_ops", - dso = ["test_ops.so"], - kernels = [ - ":test_ops_cc", - ], - srcs_version = "PY2AND3", - deps = [ - ":gen_test_ops", - ], -) +gen_op_bindings(name = "test") diff --git a/tensorflow/compiler/mlir/tfr/tfr.bzl b/tensorflow/compiler/mlir/tfr/tfr.bzl deleted file mode 100644 index cc1b617f932..00000000000 --- a/tensorflow/compiler/mlir/tfr/tfr.bzl +++ /dev/null @@ -1,120 +0,0 @@ -"""BUILD extension for TF composition project.""" - -load("//tensorflow:tensorflow.bzl", "py_binary", "tf_gen_op_wrapper_py") -load("//tensorflow:tensorflow.google.bzl", "pytype_library") - -def gen_op_libraries( - name, - src, - deps, - tags = [], - test = False): - """gen_op_libraries() generates all cc and py libraries for composite op source. - - Args: - name: used as the name component of all the generated libraries. - src: File contains the composite ops. - deps: Libraries the 'src' depends on. - tags: - test: - """ - if not src.endswith(".py") or name == src[:-3]: - fail("'src' %s conflicts with op Python wrapper. Rename it to be different from 'name'." % src) - - gen_op_lib_exec = src[:-3] - py_binary( - name = gen_op_lib_exec, - srcs = [src], - srcs_version = "PY2AND3", - python_version = "PY3", - deps = [ - "//tensorflow/python:platform", - ] + deps, - ) - - register_op = "register_" + name - native.genrule( - name = register_op, - srcs = [], - outs = [name + ".inc.cc"], - cmd = "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec, - exec_tools = [":" + gen_op_lib_exec], - local = 1, - tags = tags, - ) - - native.cc_library( - name = name + "_cc", - testonly = test, - srcs = [":" + register_op], - copts = [ - "-Wno-unused-result", - "-Wno-unused-variable", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], - alwayslink = 1, - ) - - tf_gen_op_wrapper_py( - name = name, - out = name + ".py", - deps = [ - ":%s_cc" % name, - ], - ) - - pytype_library( - name = name + "_grads", - srcs = [ - src, - ], - srcs_version = "PY2AND3", - deps = [ - "//third_party/py/numpy", - "//third_party/py/tensorflow", - ] + deps, - ) - - pytype_library( - name = name + "_lib", - srcs = [ - name + ".py", - ], - srcs_version = "PY2AND3", - deps = [ - ":%s" % name, - ":%s_cc" % name, - ":%s_grads" % name, - "//third_party/py/numpy", - "//third_party/py/tensorflow", - ] + deps, - ) - - # Link the register op and rebuild the binary - gen_tfr_lib_exec = gen_op_lib_exec + "_registered" - py_binary( - name = gen_tfr_lib_exec, - main = src, - srcs = [src], - srcs_version = "PY2AND3", - python_version = "PY3", - deps = [ - "//tensorflow/python:platform", - ":%s" % name + "_cc", - ] + deps, - ) - - op_tfr = "composite_" + name - native.genrule( - name = op_tfr, - srcs = [], - outs = [name + ".mlir"], - cmd = "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec, - exec_tools = [":" + gen_tfr_lib_exec], - local = 1, - tags = tags, - ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index b4cbf765c79..2e402f2be22 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -77,6 +77,8 @@ cc_library( "@llvm-project//mlir:SCFToGPUPass", "@llvm-project//mlir:SCFToStandard", "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:ShapeToStandard", + "@llvm-project//mlir:ShapeTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", @@ -85,7 +87,10 @@ cc_library( tf_cc_binary( name = "tf_to_gpu_binary", - srcs = ["tf_to_gpu_binary.cc"], + srcs = [ + "crash_handler.h", + "tf_to_gpu_binary.cc", + ], visibility = [ "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary:__pkg__", "//tensorflow/core/kernels/mlir_generated:__pkg__", @@ -95,6 +100,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", + "//tensorflow/core/platform", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/crash_handler.h b/tensorflow/compiler/mlir/tools/kernel_gen/crash_handler.h new file mode 100644 index 00000000000..9ecaa40e567 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/crash_handler.h @@ -0,0 +1,38 @@ +// Copyright 2020 The TensorFlow Runtime Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CRASH_HANDLER_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CRASH_HANDLER_H_ + +#include "llvm/Support/PrettyStackTrace.h" +#include "tensorflow/core/platform/platform.h" + +namespace tensorflow { +namespace kernel_gen { + +inline void SetCrashReportMessage() { +#if defined(PLATFORM_GOOGLE) + llvm::setBugReportMsg( + "The TensorFlow Kernel Generator crashed, see the docs at " + "go/tf-kernel-gen for debug hints and contact information.\n"); +#else + llvm::setBugReportMsg( + "The TensorFlow Kernel Generator crashed, please report a bug with the " + "trace below on https://github.com/tensorflow/tensorflow/issues.\n"); +#endif +} +} // namespace kernel_gen +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CRASH_HANDLER_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index b3d92773be4..676e1849318 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -61,10 +61,10 @@ LogicalResult Verify(OpTy op) { } //===----------------------------------------------------------------------===// -// AllocRawOp +// TFAllocOp //===----------------------------------------------------------------------===// template <> -LogicalResult Verify(AllocRawOp op) { +LogicalResult Verify(TFAllocOp op) { // Check that the total number of operands matches the number of dynamic // dimensions specified in the memref type. unsigned result_dyn_dims = op.getType().getNumDynamicDims(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index e6e29bcbdc2..2f3e0f6f5fa 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -49,21 +49,28 @@ class TFFramework_Op traits = []> : } //===----------------------------------------------------------------------===// -// AllocRawOp +// TFAllocOp //===----------------------------------------------------------------------===// -def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw", +def TFFramework_TFAllocOp : TFFramework_Op<"alloc", [MemoryEffects<[MemAlloc]>]> { let summary = "allocation of tensors that uses TF Framework"; let description = [{ Allocation of tensors during kernel execution in the Compute method. - This should be used to allocate any temporary or output memref. - Corresponds to `Allocator::AllocateRaw` in - tensorflow/core/framework/allocator.h. + This should be used to allocate any temporary or output memref. If + `output_index` and `input_indices` are given, attempts to forward one of + the input tensors to the output by calling `OpKernelContext::forward_input`. + + If the attributes are missing or the forwarding fails, calls + `Allocator::AllocateRaw` in tensorflow/core/framework/allocator.h. }]; - let arguments = (ins TFFramework_OpKernelContextType:$ctx, - Variadic:$dyn_sizes); + let arguments = (ins + TFFramework_OpKernelContextType:$ctx, + Variadic:$dyn_sizes, + OptionalAttr:$input_indices, + OptionalAttr:$output_index + ); let results = (outs Res]>:$result); let builders = [ @@ -92,16 +99,16 @@ def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw", } //===----------------------------------------------------------------------===// -// DeallocRawOp +// TFDeallocOp //===----------------------------------------------------------------------===// -def TFFramework_DeallocRawOp : TFFramework_Op<"dealloc_raw", +def TFFramework_TFDeallocOp : TFFramework_Op<"dealloc", [MemoryEffects<[MemFree]>]> { let summary = "deallocation of tensors that uses TF Framework"; let description = [{ Deallocation of tensors during kernel execution in the Compute method. This should be used to deallocate any temporary memref that was allocated - with `tf_framework.alloc_raw`. + with `tf_framework.alloc`. Corresponds to `Allocator::DeallocateRaw` in tensorflow/core/framework/allocator.h. }]; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 6c95323ed37..5a0fa9e2296 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project +#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project #include "mlir/Dialect/GPU/Passes.h" // from @llvm-project @@ -36,6 +37,7 @@ limitations under the License. #include "mlir/Dialect/SCF/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project +#include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -81,15 +83,19 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, pm.addPass(mlir::mhlo::createLegalizeToLhloPass( /*results_escape_functions=*/true)); // Moving `AllocOp`s and inserting missing `DeallocOp`s - pm.addPass(::mlir::createBufferPlacementPass()); + pm.addPass(::mlir::createBufferHoistingPass()); + pm.addPass(::mlir::createBufferDeallocationPass()); pm.addNestedPass(mlir::createCopyRemovalPass()); pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass()); } else { pm.addPass(mlir::mhlo::createLegalizeTFPass( /*allow_partial_conversion=*/false, /*legalize_chlo=*/false)); - pm.addPass(mlir::mhlo::createChloLegalizeToHloPass()); pm.addPass(mlir::createTransformUnrankedHloPass()); + pm.addPass(mlir::mhlo::createChloLegalizeToHloPass()); pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass()); + // Clean up the IR created above. In particular, operations on descriptors + // are simplified here. + pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass()); pm.addPass(mlir::kernel_gen::transforms::CreateParallelLoopsToSequential()); } @@ -166,6 +172,11 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, pm.addPass(xla::mlir_gpu::createRewriteKernelSignaturePass()); } pm.addPass(::mlir::createLowerAffinePass()); + + // Constraints are removed as late as possible and before lowering to CFG. + pm.addPass(::mlir::createConvertShapeConstraintsPass()); + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + pm.addPass(::mlir::createLowerToCFGPass()); if (failed(pm.run(module))) { return InternalError("Lowering to GPU kernels failed."); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir index bb0f1926cda..5d0beb7c7fe 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir @@ -10,9 +10,9 @@ func @tf_entry(%size_0 : index , %size_2 : index) -> index dealloc %buf : memref std.return %size_0 : index } -// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc_raw +// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc // CHECK-SAME: ([[CTX]], [[SIZE_0]], [[SIZE_2]]) : memref -// CHECK-NEXT: tf_framework.dealloc_raw([[CTX]], [[VAL_3]]) : memref +// CHECK-NEXT: tf_framework.dealloc([[CTX]], [[VAL_3]]) : memref // CHECK-NEXT: return [[SIZE_0]] : index // ----- diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir index 1d1b3319515..1d3d5e485fb 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir @@ -2,6 +2,6 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) { // expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}} - %buf = tf_framework.alloc_raw(%ctx, %size) : memref + %buf = tf_framework.alloc(%ctx, %size) : memref return } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir index fc8e7c97ec8..aa291c4c439 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir @@ -4,17 +4,28 @@ // Verify the generic form can be parsed. // RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt | FileCheck %s -// CHECK-LABEL: func @alloc_raw -func @alloc_raw(%ctx: !tf_framework.op_kernel_context, +// CHECK-LABEL: func @alloc +func @alloc(%ctx: !tf_framework.op_kernel_context, %size_0 : index , %size_2 : index) { - %buf_0 = tf_framework.alloc_raw(%ctx) : memref<10xi8> - %buf_1 = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref + %buf_0 = tf_framework.alloc(%ctx) : memref<10xi8> + %buf_1 = tf_framework.alloc(%ctx, %size_0, %size_2) : memref return } -// CHECK-LABEL: func @dealloc_raw -func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, %memref : memref) { - tf_framework.dealloc_raw(%ctx, %memref) : memref +// CHECK-LABEL: func @forwarding_alloc +func @forwarding_alloc(%ctx: !tf_framework.op_kernel_context, + %size_0 : index , %size_2 : index) { + %buf = tf_framework.alloc(%ctx, %size_0, %size_2) { + input_indices = [0 : i32, 1 : i32], + output_index = 0 : i32 + } : memref + return +} + +// CHECK-LABEL: func @dealloc +func @dealloc(%ctx: !tf_framework.op_kernel_context, + %memref : memref) { + tf_framework.dealloc(%ctx, %memref) : memref return } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir index b943321e95b..8530e4eccde 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir @@ -1,21 +1,21 @@ // RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s -// CHECK: llvm.func @_mlir_ciface_tf_alloc_raw -// CHECK-SAME: (!llvm.ptr, !llvm.i64) -> !llvm.ptr +// CHECK: llvm.func @_mlir_ciface_tf_alloc +// CHECK-SAME: (!llvm.ptr, !llvm.i64, !llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr) -> !llvm.ptr -// CHECK-LABEL: llvm.func @alloc_raw( +// CHECK-LABEL: llvm.func @alloc( // CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr, // CHECK-SAME: [[SIZE_0:%.*]]: !llvm.i64, // CHECK-SAME: [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] { -func @alloc_raw(%ctx: !tf_framework.op_kernel_context, +func @alloc(%ctx: !tf_framework.op_kernel_context, %size_0 : index , %size_2 : index) -> memref { - %buf = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref + %buf = tf_framework.alloc(%ctx, %size_0, %size_2) : memref std.return %buf : memref } // Compute number of elements. // CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 // CHECK: [[NUM_ELEM_0:%.*]] = llvm.mul [[SIZE_0]], [[SIZE_1]] : !llvm.i64 -// CHECK: [[NUM_ELEM_1:%.*]] = llvm.mul [[NUM_ELEM_0]], [[SIZE_2]] : !llvm.i64 +// CHECK: [[NUM_ELEMS:%.*]] = llvm.mul [[NUM_ELEM_0]], [[SIZE_2]] : !llvm.i64 // Compute the size of an individual element. // CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr @@ -25,10 +25,15 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context, // CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]] // CHECK-SAME: !llvm.ptr to !llvm.i64 +// Compute output index (-1) and candidate indices (0, NULL). +// CHECK: [[OUTPUT_INDEX:%.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32 +// CHECK-NEXT: [[NUM_CANDIDATES:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK-NEXT: [[CANDIDATES_PTR:%.*]] = llvm.mlir.null : !llvm.ptr + // Allocate memory. -// CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]] -// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]]) -// CHECK-SAME: (!llvm.ptr, !llvm.i64) -> !llvm.ptr +// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_ELEMS]], +// CHECK-SAME: [[SIZE_OF_FLOAT]], [[OUTPUT_INDEX]], [[NUM_CANDIDATES]], +// CHECK-SAME: [[CANDIDATES_PTR]]) // Build memref descriptor. // CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]] @@ -55,13 +60,13 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context, // ----- -// CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr, !llvm.ptr) +// CHECK: llvm.func @_mlir_ciface_tf_dealloc(!llvm.ptr, !llvm.ptr) -// CHECK-LABEL: llvm.func @dealloc_raw( +// CHECK-LABEL: llvm.func @dealloc( // CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr, -func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, +func @dealloc(%ctx: !tf_framework.op_kernel_context, %memref : memref) { - tf_framework.dealloc_raw(%ctx, %memref) : memref + tf_framework.dealloc(%ctx, %memref) : memref return } // Extract allocated ptr from the memref descriptor. @@ -71,5 +76,5 @@ func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, // CHECK-SAME: !llvm.ptr to !llvm.ptr // Deallocate. -// CHECK: llvm.call @_mlir_ciface_tf_dealloc_raw( +// CHECK: llvm.call @_mlir_ciface_tf_dealloc( // CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr, !llvm.ptr) -> () diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc index e75db59d885..11dc473f691 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc @@ -24,23 +24,50 @@ namespace tf_framework { namespace { using tensorflow::Allocator; +using tensorflow::AllocatorAttributes; Allocator* GetAllocator(void* op_kernel_ctx) { auto* ctx = static_cast(op_kernel_ctx); // TODO(pifon): Figure out how to set AllocatorAttributes correctly. - tensorflow::AllocatorAttributes attrs; + AllocatorAttributes attrs; return ctx->get_allocator(attrs); } } // namespace -extern "C" void* _mlir_ciface_tf_alloc_raw(void* op_kernel_ctx, - size_t num_bytes) { +extern "C" void* _mlir_ciface_tf_alloc(void* op_kernel_ctx, size_t num_elements, + size_t element_size, + int32_t output_index, + int32_t num_candidates, + int32_t* candidate_input_indices) { + auto* ctx = static_cast(op_kernel_ctx); + if (output_index != -1) { + // Create a 1D shape, because the shapes don't have to match exactly for + // input forwarding. Only the number of elements must be the same. + tensorflow::TensorShape output_shape; + output_shape.AddDim(num_elements); + + // Iterate over indices of all inputs that can potentially be used for + // forwarding. + for (int i = 0; i < num_candidates; ++i) { + // TODO(pifon): Expose fetching AllocatorAttributes with the output_index. + AllocatorAttributes output_attr; + auto tensor = ctx->forward_input( + candidate_input_indices[i], output_index, + ctx->expected_output_dtype(output_index), output_shape, + ctx->output_memory_type(output_index), output_attr); + if (tensor != nullptr) { + return tensor->data(); + } + } + } + // If no forwarding happened, allocate a chunk of memory. return GetAllocator(op_kernel_ctx) - ->AllocateRaw(Allocator::kAllocatorAlignment, num_bytes); + ->AllocateRaw(Allocator::kAllocatorAlignment, + num_elements * element_size); } -extern "C" void _mlir_ciface_tf_dealloc_raw(void* op_kernel_ctx, void* ptr) { +extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) { GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h index 143ebc95932..607cd38a1aa 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h @@ -22,10 +22,12 @@ namespace mlir { namespace kernel_gen { namespace tf_framework { -extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc_raw( - void* op_kernel_ctx, size_t num_bytes); +extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc( + void* op_kernel_ctx, size_t num_elements, size_t element_size, + int32_t output_index, int32_t num_candidates, + int32_t* candidate_input_indices); -extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc_raw( +extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc( void* op_kernel_ctx, void* ptr); } // namespace tf_framework diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc index 84c2bf46b55..7ecae51c194 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc @@ -25,6 +25,7 @@ #include "llvm/Support/CommandLine.h" #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/crash_handler.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -63,6 +64,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, } // namespace tensorflow int main(int argc, char** argv) { + tensorflow::kernel_gen::SetCrashReportMessage(); llvm::cl::opt input_file("input", llvm::cl::desc("input file"), llvm::cl::value_desc("filename"), llvm::cl::init("foo.mlir")); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc index 4cbe2e633f8..87c8e57804b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc @@ -62,7 +62,7 @@ std::unique_ptr GetTargetMachine(llvm::Module* module) { } llvm::TargetOptions target_options = - llvm::codegen::InitTargetOptionsFromCodeGenFlags(); + llvm::codegen::InitTargetOptionsFromCodeGenFlags(llvm::Triple()); return std::unique_ptr(target->createTargetMachine( triple.str(), "generic", "", target_options, llvm::Reloc::Model::PIC_)); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc index f6a47cca6d0..b9ff5cdb287 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc @@ -40,10 +40,9 @@ namespace transforms { namespace { class TensorFromElementsOpConverter - : public BufferAssignmentOpConversionPattern { + : public OpConversionPattern { public: - using BufferAssignmentOpConversionPattern< - TensorFromElementsOp>::BufferAssignmentOpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( TensorFromElementsOp op, ArrayRef operands, @@ -64,10 +63,9 @@ class TensorFromElementsOpConverter }; class DynamicTensorFromElementsOpConverter - : public BufferAssignmentOpConversionPattern { + : public OpConversionPattern { public: - using BufferAssignmentOpConversionPattern< - DynamicTensorFromElementsOp>::BufferAssignmentOpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( DynamicTensorFromElementsOp op, ArrayRef operands, @@ -115,11 +113,9 @@ class DynamicTensorFromElementsOpConverter } }; -class TensorLoadOpConversion - : public BufferAssignmentOpConversionPattern { +class TensorLoadOpConversion : public OpConversionPattern { public: - using BufferAssignmentOpConversionPattern< - TensorLoadOp>::BufferAssignmentOpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( TensorLoadOp op, ArrayRef operands, @@ -131,10 +127,9 @@ class TensorLoadOpConversion }; class ExtractElementOpConversion - : public BufferAssignmentOpConversionPattern { + : public OpConversionPattern { public: - using BufferAssignmentOpConversionPattern< - ExtractElementOp>::BufferAssignmentOpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( ExtractElementOp op, ArrayRef operands, @@ -152,27 +147,22 @@ class ExtractElementOpConversion }; template -class SimpleOpResultConversion - : public BufferAssignmentOpConversionPattern { +class SimpleOpResultConversion : public OpConversionPattern { public: - using BufferAssignmentOpConversionPattern< - OpTy>::BufferAssignmentOpConversionPattern; - using BufferAssignmentOpConversionPattern::converter; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( OpTy op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp(op, converter.convertType(op.getType()), - operands); + rewriter.replaceOpWithNewOp( + op, this->getTypeConverter()->convertType(op.getType()), operands); return success(); } }; -class TensorCastOpConverter - : public BufferAssignmentOpConversionPattern { +class TensorCastOpConverter : public OpConversionPattern { public: - using BufferAssignmentOpConversionPattern< - TensorCastOp>::BufferAssignmentOpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( TensorCastOp op, ArrayRef operands, @@ -180,7 +170,7 @@ class TensorCastOpConverter Value arg = operands.front(); if (!arg.getType().isa()) return failure(); - auto result_ty = converter.convertType(op.getType()); + auto result_ty = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, arg, result_ty); return success(); @@ -190,12 +180,12 @@ class TensorCastOpConverter } // namespace void populateStandardBufferizePattern(MLIRContext *context, - BufferAssignmentTypeConverter *converter, + BufferizeTypeConverter *converter, OwningRewritePatternList *patterns) { patterns->insert, TensorLoadOpConversion, - TensorCastOpConverter>(context, *converter); + TensorCastOpConverter>(*converter, context); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc index 1ddc44123fa..9a531515012 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -86,7 +86,7 @@ struct BufferizePass : public BufferizePassBase { return !op.tensor().getType().isa(); }); - BufferAssignmentTypeConverter converter; + BufferizeTypeConverter converter; auto typesAreLegal = [&converter](Operation* op) { return converter.isLegal(op->getOperandTypes()) && converter.isLegal(op->getResultTypes()); @@ -102,8 +102,8 @@ struct BufferizePass : public BufferizePassBase { OwningRewritePatternList patterns; mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns); - populateWithBufferAssignmentOpConversionPatterns( + populateWithBufferizeOpConversionPatterns( &context, converter, patterns); populateStandardBufferizePattern(&context, &converter, &patterns); populateShapeTypeConversionPatterns(&context, converter, patterns); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc index aa02aefa9d2..3b006c954cf 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc @@ -75,19 +75,19 @@ class AllocOpConverter : public OpConversionPattern { return failure(); } // Symbolic operands that bind to the symbols of the memref's layout map are - // not supported by AllocRawOp. + // not supported by TFAllocOp. if (alloc.getNumSymbolicOperands() != 0) { return failure(); } - rewriter.replaceOpWithNewOp(alloc, alloc.getType(), ctx, - operands); + rewriter.replaceOpWithNewOp(alloc, alloc.getType(), ctx, + operands); return success(); } }; // Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType // arg of the parent function. -class DeallocOpConverter : public OpConversionPattern { +class TFDeallocOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -108,8 +108,8 @@ class DeallocOpConverter : public OpConversionPattern { return failure(); } DeallocOp::Adaptor transformed(operands); - rewriter.replaceOpWithNewOp(dealloc, ctx, - transformed.memref()); + rewriter.replaceOpWithNewOp(dealloc, ctx, + transformed.memref()); return success(); } }; @@ -118,7 +118,7 @@ class DeallocOpConverter : public OpConversionPattern { void PopulateEmbedTFFrameworkConversionPatterns( MLIRContext *context, OwningRewritePatternList *patterns) { - patterns->insert( + patterns->insert( context); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h index 0f2a41b3de6..32fc375d48b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h @@ -20,8 +20,7 @@ limitations under the License. namespace mlir { -class BufferAssignmentPlacer; -class BufferAssignmentTypeConverter; +class BufferizeTypeConverter; class LLVMTypeConverter; class MLIRContext; class OwningRewritePatternList; @@ -45,7 +44,7 @@ namespace transforms { /// Collects a set of patterns that bufferize operations from the standard /// dialect. void populateStandardBufferizePattern(MLIRContext *context, - BufferAssignmentTypeConverter *converter, + BufferizeTypeConverter *converter, OwningRewritePatternList *patterns); } // namespace transforms } // namespace kernel_gen diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index 431919c2de7..cc9884f97ee 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project @@ -30,8 +31,8 @@ namespace { using LLVM::LLVMFuncOp; using LLVM::LLVMType; -static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc_raw"; -static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc_raw"; +static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc"; +static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc"; /// Base class for patterns converting TF Framework ops to function calls. template @@ -60,27 +61,42 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern { virtual LLVMType GetFuncType() const = 0; }; -class AllocRawOpConverter : public ConvertToLLVMCallOpPattern { +class TFAllocOpConverter : public ConvertToLLVMCallOpPattern { public: - using ConvertToLLVMCallOpPattern::ConvertToLLVMCallOpPattern; + using ConvertToLLVMCallOpPattern::ConvertToLLVMCallOpPattern; LogicalResult matchAndRewrite( Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - AllocRawOp alloc_raw_op = cast(op); - AllocRawOp::Adaptor transformed(operands); + TFAllocOp tf_alloc_op = cast(op); + TFAllocOp::Adaptor transformed(operands); - MemRefType memref_type = alloc_raw_op.getType(); + MemRefType memref_type = tf_alloc_op.getType(); // Get memref descriptor sizes. SmallVector sizes; getMemRefDescriptorSizes(loc, memref_type, llvm::to_vector<4>(transformed.dyn_sizes()), rewriter, sizes); - // Get memory block size in bytes. - Value num_bytes = getCumulativeSizeInBytes( - loc, memref_type.getElementType(), sizes, rewriter); + // Get number of elements. + Value num_elements = getNumElements(loc, sizes, rewriter); + // Get element size. + Value element_size = + getSizeInBytes(loc, memref_type.getElementType(), rewriter); + + // Convert `output_index` or set it to -1 if the attribute is missing. + LLVM::LLVMType llvmInt32Type = + LLVM::LLVMType::getInt32Ty(rewriter.getContext()); + Value output_index = rewriter.create( + loc, llvmInt32Type, + rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue() + ? tf_alloc_op.output_index().getValue() + : -1)); + + // Convert `candidate_input_indices`. + auto candidates_count_and_ptr = ConvertI32ArrayAttrToStackAllocatedArray( + loc, tf_alloc_op.input_indices(), &rewriter); // Insert function call. FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op); @@ -88,7 +104,10 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern { rewriter .create( loc, getVoidPtrType(), tf_func_ref, - llvm::makeArrayRef({transformed.ctx(), num_bytes})) + llvm::makeArrayRef({transformed.ctx(), num_elements, + element_size, output_index, + candidates_count_and_ptr.first, + candidates_count_and_ptr.second})) .getResult(0); MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor( @@ -103,10 +122,19 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern { StringRef GetFuncName() const override { return kCInterfaceAlloc; } LLVMType GetFuncType() const override { + LLVMType llvm_i32_type = + LLVM::LLVMType::getInt32Ty(getDialect().getContext()); + LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo(); LLVMType llvm_void_ptr_type = getVoidPtrType(); - return LLVM::LLVMType::getFunctionTy( + return LLVMType::getFunctionTy( llvm_void_ptr_type, - llvm::makeArrayRef({llvm_void_ptr_type, getIndexType()}), + llvm::makeArrayRef( + {/*void* op_kernel_ctx*/ llvm_void_ptr_type, + /*size_t num_elements*/ getIndexType(), + /*size_t element_size*/ getIndexType(), + /*int32_t output_index*/ llvm_i32_type, + /*int32_t num_candidates*/ llvm_i32_type, + /*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}), /*isVarArg=*/false); } @@ -144,16 +172,53 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern { } return memref_desc; } + + std::pair ConvertI32ArrayAttrToStackAllocatedArray( + Location loc, llvm::Optional attr, + ConversionPatternRewriter *rewriter) const { + LLVMType llvm_i32_type = + LLVM::LLVMType::getInt32Ty(getDialect().getContext()); + LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo(); + + // If the attribute is missing or empty, set the element count to 0 and + // return NULL. + if (!attr.hasValue() || attr.getValue().empty()) { + Value zero = rewriter->create( + loc, llvm_i32_type, rewriter->getI32IntegerAttr(0)); + Value null_ptr = rewriter->create(loc, llvm_i32_ptr_type); + return std::make_pair(zero, null_ptr); + } + + // Allocate array to store the elements. + auto &array_attr = attr.getValue(); + Value array_size = rewriter->create( + loc, llvm_i32_type, rewriter->getI32IntegerAttr(array_attr.size())); + Value array_ptr = rewriter->create( + loc, llvm_i32_ptr_type, array_size, /*alignment=*/0); + + for (auto &dim : llvm::enumerate(array_attr)) { + Value index = rewriter->create( + loc, llvm_i32_type, rewriter->getI32IntegerAttr(dim.index())); + Value elem_ptr = rewriter->create(loc, llvm_i32_ptr_type, + array_ptr, index); + Value elem = rewriter->create( + loc, llvm_i32_type, + rewriter->getI32IntegerAttr( + dim.value().cast().getInt())); + rewriter->create(loc, elem, elem_ptr); + } + return std::make_pair(array_size, array_ptr); + } }; -class DeallocRawOpConverter : public ConvertToLLVMCallOpPattern { +class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern { public: - using ConvertToLLVMCallOpPattern::ConvertToLLVMCallOpPattern; + using ConvertToLLVMCallOpPattern::ConvertToLLVMCallOpPattern; LogicalResult matchAndRewrite( Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - DeallocRawOp::Adaptor transformed(operands); + TFDeallocOp::Adaptor transformed(operands); MemRefDescriptor memref(transformed.memref()); Value allocated_bytes_ptr = rewriter.create( @@ -194,7 +259,7 @@ class NullContextOpConverter : public ConvertOpToLLVMPattern { void PopulateTFFrameworkToLLVMConversionPatterns( LLVMTypeConverter *converter, OwningRewritePatternList *patterns) { patterns->insert(*converter); - patterns->insert(*converter); + patterns->insert(*converter); } } // namespace tf_framework diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 1919446a365..32dd1e202ee 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -428,8 +428,8 @@ tf_cc_binary( name = "xla-opt", deps = [ ":all_xla_passes_for_testing", - "//tensorflow/compiler/jit:xla_cpu_jit", - "//tensorflow/compiler/jit:xla_gpu_jit", "//tensorflow/compiler/mlir:tf_mlir_opt_main", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:gpu_plugin", ], ) diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 253156b44a5..d682b6cb44b 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -806,6 +806,7 @@ StatusOr HloFunctionImporter::ImportInstruction( // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions // in physical minor-to-major order. if (instruction->shape().IsArray() && + !instruction->shape().layout().minor_to_major().empty() && instruction->shape().layout() != LayoutUtil::MakeDescendingLayout( instruction->shape().dimensions().size())) { diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index ccfcebab60e..875f521f520 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -1012,13 +1012,24 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) { &comparator))) return failure(); - auto tupled = xla::Sort(GetTuple(op.operands(), ctx), comparator, + auto sorted = xla::Sort(GetTuple(op.operands(), ctx), comparator, op.dimension(), op.is_stable()); auto& value_map = *ctx.values; + auto shape_or = sorted.builder()->GetShape(sorted); + if (!shape_or.ok()) { + return op.emitError(shape_or.status().ToString()); + } + + xla::Shape& shape = shape_or.ValueOrDie(); + if (!shape.IsTuple()) { + value_map[op.getResult(0)] = sorted; + return success(); + } + // MLIR's sort supports multiple returns, untuple all the results of XLA's. for (auto it : llvm::enumerate(op.getResults())) { - value_map[it.value()] = xla::GetTupleElement(tupled, it.index()); + value_map[it.value()] = xla::GetTupleElement(sorted, it.index()); } return success(); } @@ -1169,9 +1180,9 @@ StatusOr CreateArrayLiteralFromAttr(ElementsAttr attr, } xla::Layout ExtractLayout(mlir::Operation* op, int rank) { - if (auto attr = - op->getAttrOfType("minor_to_major")) { + if (auto attr = GetLayoutFromMlirHlo(op)) { llvm::SmallVector minor_to_major; + DCHECK_EQ(rank, attr.size()); minor_to_major.reserve(attr.size()); for (const llvm::APInt& i : attr) { minor_to_major.push_back(i.getZExtValue()); @@ -1726,4 +1737,8 @@ Status ConvertMlirHloToHlo( return Status::OK(); } +DenseIntElementsAttr GetLayoutFromMlirHlo(mlir::Operation* op) { + return op->getAttrOfType("minor_to_major"); +} + } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index 4ca3e586128..a727f60084c 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -64,6 +64,8 @@ llvm::Optional<::xla::XlaOp> CreateXlaOperator( mlir::Operation* op, llvm::DenseMap* value_lowering); +mlir::DenseIntElementsAttr GetLayoutFromMlirHlo(mlir::Operation* op); + } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_ diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD index 754b14f4b13..3523dba9917 100644 --- a/tensorflow/compiler/mlir/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/BUILD @@ -1,12 +1,19 @@ load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core/platform:build_config_root.bzl", + "tf_cuda_tests_tags", +) package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", + tags_override = { + "hlo_to_lhlo_with_xla/gpu_ops.mlir": tf_cuda_tests_tags(), + }, test_file_exts = [ "mlir", "hlotxt", diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt deleted file mode 100644 index 781e203510b..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt +++ /dev/null @@ -1,16 +0,0 @@ -// RUN: tf-mlir-translate -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s - -HloModule TestModule - -// CHECK: func @TestComputation - -FusedComputation { - // CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>} - x = f32[3, 2]{0,1} parameter(0) - ROOT y = f32[3, 2]{0,1} add(x, x) -} - -ENTRY TestComputation { - x = f32[3, 2]{0,1} parameter(0) - ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation -} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/gpu_ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/gpu_ops.mlir new file mode 100644 index 00000000000..83c156554cd --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/gpu_ops.mlir @@ -0,0 +1,33 @@ +// RUN: xla-opt -split-input-file "-xla-hlo-to-lhlo-with-xla=platform=CUDA" %s +//// | FILECHECK_OPTS="" FileCheck --enable-var-scope %s + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<3x3xi32> +// CHECK-SAME: %[[ARG1:.*]]: memref<2xi32> +// CHECK-SAME: %[[ARG2:.*]]: memref<2x3xi32> +// CHECK-SAME: %[[ARG3:.*]]: memref<36xi8> {lmhlo.alloc = 0 +// CHECK: %[[VIEW0:.*]] = std.view %[[ARG3]]{{.*}} : memref<36xi8> to memref3x3xi32> +// CHECK: "lmhlo.copy"(%[[ARG0]], %[[VIEW0]]) +// CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<100xi8> to memref<5x5xf32> +// CHECK: "lmhlo.scatter"(%[[VIEW0]], %[[ARG1]], %[[ARG2]], %[[VIEW1]]) +// CHECK: mhlo.add +// CHECK: indices_are_sorted = false +// CHECK: index_vector_dim = 1 : i64 +// CHECK: inserted_window_dims = dense<0> : tensor<1xi64> +// CHECK: scatter_dims_to_operand_dims = dense<0> : tensor<1xi64> +// CHECK: update_window_dims = dense<1> : tensor<1xi64> +// CHECK: unique_indices = false +func @main(%operand:tensor<3x3xi32>, %indices: tensor<2xi32>, %updates: tensor<2x3xi32>) -> tensor<3x3xi32> { + %result = "mhlo.scatter"(%operand, %indices, %updates) ( { + ^bb0(%x: tensor, %y : tensor): + %result = "mhlo.add"(%x, %y): (tensor, tensor) -> tensor + "mhlo.return"(%result) : (tensor) -> () + }) { scatter_dimension_numbers = {index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<1> : tensor<1xi64>}, + indices_are_sorted = false, + unique_indices = false} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %result : tensor<3x3xi32> +} + diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt new file mode 100644 index 00000000000..7c42100e433 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt @@ -0,0 +1,48 @@ +// RUN: tf-mlir-translate -split-input-file -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s + +HloModule TestModule + +// CHECK: func @TestComputation + +FusedComputation { + // CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>} + x = f32[3, 2]{0,1} parameter(0) + ROOT y = f32[3, 2]{0,1} add(x, x) +} + +ENTRY TestComputation { + x = f32[3, 2]{0,1} parameter(0) + ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation +} + +// ----- + +HloModule ScatterModule + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +// CHECK: func @main +// CHECK: "lmhlo.scatter" +// CHECK: ^bb0(%[[ARG5:.*]]: tensor, %[[ARG6:.*]]: tensor): +// CHECK: "mhlo.return"(%[[ARG6]]) +// CHECK: indices_are_sorted = false +// CHECK: index_vector_dim = 1 : i64 +// CHECK: inserted_window_dims = dense<0> : tensor<1xi64> +// CHECK: scatter_dims_to_operand_dims = dense<0> : tensor<1xi64> +// CHECK: update_window_dims = dense<1> : tensor<1xi64> +// CHECK: unique_indices = false +// CHECK: (memref<3x3xi32>, memref<2xi32>, memref<2x3xi32>, memref<3x3xi32>) -> () +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter_op = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index fcbc5ebd337..d3594c30431 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -342,7 +342,7 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< } // CHECK-LABEL: fusedBatchNormGradV3_Training -func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { +func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) { // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> @@ -350,10 +350,11 @@ func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> + // CHECK: return %[[x_backprop]] + // CHECK-SAME: tensor<8x8x8x8xf32> - %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - return %0#0 : tensor<8x8x8x8xf32> + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>) + return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32> } // CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision @@ -1020,6 +1021,13 @@ func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> { return %0: tensor<1xi32> } +// CHECK-LABEL: func @identityN +func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) { + // CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32> + %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) + return %0#0, %0#1: tensor<1xi32>, tensor<1xf32> +} + // CHECK-LABEL: func @stopgradient func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-NEXT: return %arg0 : tensor<1xi32> @@ -2633,6 +2641,14 @@ func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tens return %0 : tensor<1x4xi32> } +// CHECK-LABEL: slice_mhlo_sizes +func @slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { + // CHECK-NOT: "tf.Slice" + %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "tf.Slice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> + return %1 : tensor<1x512x4xf32> +} + // CHECK-LABEL: slice_variable_start_negative_one_size func @slice_variable_start_negative_one_size(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { // CHECK: %[[RESULT:.*]] = "tf.Slice" diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index c078191d170..b857b2963f9 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -963,6 +963,22 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // CHECK: [[GET0:%.+]] = f32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=0 // CHECK: ROOT [[GET1:%.+]] = s32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=1 +// ----- + +// CHECK: HloModule +func @main(%input0: tensor<16x16xf32>) { + %0 = "mhlo.sort"(%input0) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>) -> (tensor<16x16xf32>) + return +} + +// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] { +// CHECK: ROOT %[[CMP:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT + +// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] sort(f32[16,16] %Arg_0.1), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] // ----- diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index d541270f671..ff22e74f1c4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -188,22 +188,20 @@ Type GetSumAccumulationType(Type input_type) { return input_type; } -// Returns axis in HLO format from TF elements attr with exactly one element -// containing axis in the TensorFlow format. TensorFlow format supports negative -// indexing unlike HLO. -static IntegerAttr GetHLOAxisFromTFAxis(ElementsAttr attr, int64_t rank, +// Returns axis in HLO format from TF elements attr with exactly one element or +// is an IntegerAttr, containing axis in the TensorFlow format. TensorFlow +// format supports negative indexing unlike HLO. +static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank, Builder *b) { - SmallVector index(attr.getType().getRank(), 0); - int64_t axis = attr.getValue(index).getInt(); - if (axis < 0) { - axis += rank; + IntegerAttr intAttr = attr.dyn_cast_or_null(); + if (auto elementAttr = attr.dyn_cast_or_null()) { + SmallVector index(elementAttr.getType().getRank(), 0); + intAttr = elementAttr.getValue(index); } - return b->getI64IntegerAttr(axis); -} -static IntegerAttr GetHLOAxisFromTFAxis(IntegerAttr attr, int64_t rank, - Builder *b) { - int64_t axis = attr.getInt(); + assert(intAttr && "Invalid attribute passed to GetHLOAxisFromTFAxis"); + + int64_t axis = intAttr.getInt(); if (axis < 0) { axis += rank; } @@ -1707,6 +1705,17 @@ class ConvertEinsumOp : public OpRewritePattern { } }; +// Bypasses IdentityN op. +class ConvertIdentityNOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::IdentityNOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; + template class ConvertFFTOp : public OpRewritePattern { public: @@ -1881,11 +1890,27 @@ class ConvertFusedBatchNormGradBase } x_backprop = rewriter.create(loc, x_backprop, act_ele_type); - // It doesn't matter what values we provide for the last 2 results. - rewriter.replaceOp(op, - {/*x_backprop=*/x_backprop, - /*scale_backprop=*/scale_backprop, - /*offset_backprop=*/offset_backprop, op.x(), op.x()}); + Value last_val[2]; + if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) { + // It doesn't matter what values we provide for the last 2 results. + last_val[0] = last_val[1] = op.x(); + } else { + auto const_val = rewriter.create( + op.getLoc(), + DenseElementsAttr::get( + RankedTensorType::get({0}, getElementTypeOrSelf(op.getResult(3))), + 0.0)); + auto maybe_cast = [&](Value val, Type t) -> Value { + if (val.getType() == t) return val; + return rewriter.create(op.getLoc(), t, val); + }; + last_val[0] = maybe_cast(const_val, op.getResult(3).getType()); + last_val[1] = maybe_cast(const_val, op.getResult(4).getType()); + } + rewriter.replaceOp( + op, {/*x_backprop=*/x_backprop, + /*scale_backprop=*/scale_backprop, + /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]}); return success(); } }; @@ -2023,13 +2048,25 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { /*reserve_space_1=*/reserve_space_1, /*reserve_space_2=*/batch_variance}); } else { // TF::FusedBatchNormV3Op - // FusedBatchNormV3 expects a 5th output, but the output is unused; it - // doesn't matter what we pass there. + // For FusedBatchNormV3Op, also create a constant tensor to forward to + // last reserve_space_3 output. + auto reserve_space_3_type = + op.getResult(5).getType().template cast(); + int num_elements = reserve_space_3_type.hasStaticShape() + ? reserve_space_3_type.getNumElements() + : 0; + auto const_attr_type = RankedTensorType::get( + {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); + Value dummy_const = rewriter.create( + op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); + if (const_attr_type != reserve_space_3_type) + dummy_const = rewriter.create( + op.getLoc(), reserve_space_3_type, dummy_const); rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, /*batch_variance=*/corrected_variance, /*reserve_space_1=*/reserve_space_1, /*reserve_space_2=*/batch_variance, - /*reserve_space_3=*/op.x()}); + /*reserve_space_3=*/dummy_const}); } } else { // Inference case. auto bn_train_op = rewriter.create( @@ -6117,13 +6154,14 @@ void PopulateLegalizeTfPatterns(MLIRContext *context, ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op, ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, - ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp, - ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp, - ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp, - ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp, - ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp, - ConvertDynamicRangeOp, ConvertMatrixDiagPartV3Op, ConvertRangeOp, - ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp, + ConvertIdentityNOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp, + ConvertMaxOp, ConvertMinOp, ConvertAvgPool2DOp, ConvertAvgPool3DOp, + ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, + ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, + ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, + ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp, + ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op, + ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp, ConvertSoftmaxOp, ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 52bbbf6f9da..5baef3b4afd 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -53,7 +53,7 @@ def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< def CastElementsToI64Elements : NativeCodeCall< "hlo::ConvertElementsAttr(" - "$0, $_builder.getIntegerType(64)).cast()">; + "$0.cast(), $_builder.getIntegerType(64)).cast()">; def : Pattern< (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon, @@ -253,7 +253,7 @@ def IsShapedTensor // the conversion, original Concat op operands still refers to the old ops even // if HLO constant op is introduced as an replacement for the TensorFlow // Constant op. -def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis)), +def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)), (HLO_ConcatenateOp $inputs, (GetHLOAxisFromTFAxisVariadic $axis, $inputs)), [(HasRankedFirstOperand $inputs)]>; @@ -262,7 +262,7 @@ def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis)), // CollectivePermute op patterns. //===----------------------------------------------------------------------===// -def : Pat<(TF_CollectivePermuteOp $input, (TF_ConstOp $source_target_pairs)), +def : Pat<(TF_CollectivePermuteOp $input, (ConstantLikeMatcher ElementsAttr:$source_target_pairs)), (HLO_CollectivePermuteOp $input, (CastElementsToI64Elements $source_target_pairs))>; @@ -270,7 +270,7 @@ def : Pat<(TF_CollectivePermuteOp $input, (TF_ConstOp $source_target_pairs)), // CrossReplicaSum op patterns. //===----------------------------------------------------------------------===// -def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), +def : Pat<(TF_CrossReplicaSumOp $input, (ConstantLikeMatcher ElementsAttr:$group_assignment)), (HLO_CrossReplicaSumOp $input, (CastElementsToI64Elements $group_assignment))>; @@ -278,7 +278,7 @@ def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), // All2All op patterns. //===----------------------------------------------------------------------===// -def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), +def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (ConstantLikeMatcher ElementsAttr:$group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), (HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>; //===----------------------------------------------------------------------===// @@ -308,7 +308,7 @@ def : Pat<(TF_IFFTOp:$res $input), // indexing to the HLO format. def LegalizeGatherV2 : Pat<(TF_GatherV2Op AnyRankedTensor:$params, AnyRankedTensor:$indices, - (TF_ConstOp $axis), $batch_dims), + (ConstantLikeMatcher ElementsAttr:$axis), $batch_dims), (HLO_TorchIndexSelectOp $params, $indices, (GetHLOAxisFromTFAxis $axis, $params), (GetHLOAxisFromTFAxis $batch_dims, $indices))>; @@ -318,16 +318,16 @@ def LegalizeGatherV2 : //===----------------------------------------------------------------------===// class SliceDenseIntElementsAttrColumn2D : NativeCodeCall< - "SliceDenseIntElementsAttrColumn2D($0, " # column # " )">; + "SliceDenseIntElementsAttrColumn2D($0.cast(), " # column # " )">; class SliceDenseIntElementsAttr : NativeCodeCall< - "SliceDenseIntElementsAttr($0, " # index # ", " # axis # ")">; + "SliceDenseIntElementsAttr($0.cast(), " # index # ", " # axis # ")">; // Interior padding attribute based on the TF padding. def GetInteriorPadding : NativeCodeCall < - "GetInteriorPadding($0)">; + "GetInteriorPadding($0.cast())">; -def : Pat<(TF_PadV2Op $input, (TF_ConstOp $padding), $c), +def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), (HLO_PadOp $input, $c, (SliceDenseIntElementsAttrColumn2D<"0"> $padding), (SliceDenseIntElementsAttrColumn2D<"1"> $padding), @@ -528,7 +528,7 @@ def TFSliceSizes2HLOSliceSizes : NativeCodeCall< "&$_builder)">; def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, - (TF_ConstOp $slice_sizes)), + (ConstantLikeMatcher AnyAttr:$slice_sizes)), (HLO_DynamicSliceOp $input, (CastToI64AndUnpackTensor $op, $starting_indices), (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)), @@ -560,9 +560,9 @@ def : Pat<(TF_LegacyCallOp:$op $args, FlatSymbolRefAttr:$f, $attr), //===----------------------------------------------------------------------===// // Handles axis conversion for TF reverse. -def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1, &$_builder)">; +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast(), &$_builder)">; -def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)), +def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)), (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; //===----------------------------------------------------------------------===// @@ -603,7 +603,7 @@ foreach Mapping = [ def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse), (HLO_ConvertOp $arg)>; -def : Pat<(TF_TransposeOp:$res $arg, (TF_ConstOp $permutation)), +def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)), (HLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>; // Result of the following ops changing tensor shape needs to have static @@ -707,7 +707,7 @@ def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)"> def HasValidGatherDims : Constraint>; -def : Pat<(TF_XlaGatherOp $operand, $start_indices, (TF_ConstOp $slice_sizes), +def : Pat<(TF_XlaGatherOp $operand, $start_indices, (ConstantLikeMatcher ElementsAttr:$slice_sizes), $dimension_numbers, $indices_are_sorted), (HLO_GatherOp $operand, $start_indices, (ToGatherDimNumsAttr $dimension_numbers), diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 6e4a2495a4f..5098e581fd6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -161,7 +161,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -229,6 +228,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index 9cf161fb2ae..0efb8a16ba3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -134,8 +134,8 @@ Status ConvertModule(std::unique_ptr hlo_module, ModuleOp module, return Status::OK(); } -// This pass take a MLIR HLO module, convert it to XLA to perform the HLO -// optimization pipeline for the required platform, and then convert back to +// This pass takes an MLIR HLO module, converts it to XLA to perform the HLO +// optimization pipeline for the required platform, and then converts it back to // MLIR LHLO. class XlaHloToLhloPass : public PassWrapper> { @@ -392,7 +392,6 @@ StatusOr LhloDialectEmitter::EmitFusionOp( } } - LOG(ERROR) << instr->GetModule()->ToString(); builder_.restoreInsertionPoint(after_fusion); return fusion; } @@ -401,6 +400,49 @@ Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) { return EmitFusionOp(instr).status(); } +StatusOr +LhloDialectEmitter::GetScatterDimensionNumbers(HloInstruction* instr) { + auto* scatter_instr = ::xla::Cast<::xla::HloScatterInstruction>(instr); + + const ::xla::ScatterDimensionNumbers& xla_scatter_dim = + scatter_instr->scatter_dimension_numbers(); + auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbers::get( + getI64DenseElementsAttr(xla_scatter_dim.update_window_dims()), + getI64DenseElementsAttr(xla_scatter_dim.inserted_window_dims()), + getI64DenseElementsAttr(xla_scatter_dim.scatter_dims_to_operand_dims()), + builder_.getI64IntegerAttr(xla_scatter_dim.index_vector_dim()), + module_.getContext()); + return scatter_dimension_numbers; +} + +StatusOr LhloDialectEmitter::EmitScatterOp( + HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto scatter, + CreateOpWithoutAttrs(instr)); + + // copy attributes + auto* scatter_instr = ::xla::Cast<::xla::HloScatterInstruction>(instr); + + TF_ASSIGN_OR_RETURN(auto scatter_dimension_numbers, + GetScatterDimensionNumbers(instr)); + scatter.scatter_dimension_numbersAttr(scatter_dimension_numbers); + scatter.indices_are_sortedAttr( + builder_.getBoolAttr(scatter_instr->indices_are_sorted())); + scatter.unique_indicesAttr( + builder_.getBoolAttr(scatter_instr->unique_indices())); + + // import update computation as region + TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion( + *scatter_instr->called_computations()[0], &scatter.update_computation(), + &builder_)); + + return scatter; +} + +Status LhloDialectEmitter::HandleScatter(HloInstruction* instr) { + return EmitScatterOp(instr).status(); +} + StatusOr LhloDialectEmitter::GetOrCreateArrayView( const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape, const ::xla::ShapeIndex& shape_index) { diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index a57db3cb67e..97a9b17e81d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -16,11 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" namespace mlir { @@ -44,11 +46,20 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { ::xla::StatusOr EmitSortOp(::xla::HloInstruction* instr); ::xla::StatusOr EmitFusionOp(::xla::HloInstruction* instr); + ::xla::StatusOr EmitScatterOp(::xla::HloInstruction* instr); + ::xla::StatusOr GetScatterDimensionNumbers( + ::xla::HloInstruction* instr); private: template ::xla::StatusOr CreateOpWithoutAttrs(::xla::HloInstruction* instr); + template + DenseIntElementsAttr getI64DenseElementsAttr(const T& container) { + return builder_.getI64TensorAttr( + {container.data(), static_cast(container.size())}); + } + tensorflow::Status DefaultAction(::xla::HloInstruction* instr) final; // Computation parameters don't need any specific handling when they are @@ -59,6 +70,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { tensorflow::Status HandleSort(::xla::HloInstruction* instr) final; tensorflow::Status HandleFusion(::xla::HloInstruction* instr) final; + tensorflow::Status HandleScatter(::xla::HloInstruction* instr) final; // Helper function that recursively visits the tuple structure in // `current_shape`, and reconstruct a matching lmhlo::TupleOp. diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 1dfcf88e654..ecffd276f10 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -327,7 +327,6 @@ tf_xla_py_test( name = "self_adjoint_eig_op_test", size = "medium", srcs = ["self_adjoint_eig_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -393,7 +392,6 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_inverse_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -416,7 +414,6 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_solve_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -639,7 +636,6 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -696,7 +692,6 @@ tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 6, tags = [ @@ -758,6 +753,7 @@ tf_xla_py_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notsan", # TODO(b/171000704): data race ], deps = [ ":xla_test", @@ -1018,7 +1014,6 @@ tf_xla_py_test( "cpu", "cpu_ondemand", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1089,7 +1084,6 @@ tf_xla_py_test( name = "reduce_ops_test", size = "medium", srcs = ["reduce_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1391,7 +1385,6 @@ tf_xla_py_test( name = "unary_ops_test", size = "medium", srcs = ["unary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1531,6 +1524,7 @@ tf_xla_py_test( name = "scatter_nd_op_test", size = "medium", srcs = ["scatter_nd_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1548,7 +1542,6 @@ tf_xla_py_test( name = "sort_ops_test", size = "medium", srcs = ["sort_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 1, # Times out in fastbuild mode. @@ -1705,6 +1698,7 @@ tf_cuda_cc_test( name = "unary_ops_composition_test", srcs = ["unary_ops_composition_test.cc"], tags = [ + "no_cuda_asan", # TODO(b/171317888): re-enable. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ] + tf_cuda_tests_tags(), deps = [ @@ -1790,7 +1784,6 @@ tf_xla_py_test( name = "fake_quant_ops_test", size = "medium", srcs = ["fake_quant_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 3adb169e7f0..04531108b70 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -161,6 +162,7 @@ class ScatterNdTest(xla_test.XLATestCase): expected = np.zeros([2, 2], dtype=np.int32) self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2])) + @test_util.disable_mlir_bridge("Error messages differ") def testRank3InvalidShape1(self): indices = np.zeros([3, 2, 2], np.int32) updates = np.zeros([2, 2, 2], np.int32) @@ -168,6 +170,7 @@ class ScatterNdTest(xla_test.XLATestCase): "Must have updates.shape"): self._runScatterNd(indices, updates, [2, 2, 2]) + @test_util.disable_mlir_bridge("Error messages differ") def testRank3InvalidShape2(self): indices = np.zeros([2, 2, 1], np.int32) updates = np.zeros([2, 2], np.int32) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index e7183cd7e18..a82c1c485b9 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -483,7 +483,7 @@ tf_cuda_cc_test( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensor_testutil", + "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 2804a381e0c..8fbe0f4ceb9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -60,8 +60,12 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace convert { + using absl::StrAppend; using absl::StrCat; +using ::tensorflow::tensorrt::segment::ClusterProperty; +using ::tensorflow::tensorrt::segment::NodePtrCompare; +using ::tensorflow::tensorrt::segment::Segment; namespace { @@ -125,15 +129,21 @@ bool ShallKeepControlEdgeFrom(const Node* input_node) { // Function to get subsegment information structure. Status GetEngineInfo(const Graph* g, const grappler::GraphProperties& graph_properties, - const std::set& segment_nodes, + const Segment& segment, const std::unordered_map& node_map, const std::vector& reverse_topo_order, EngineInfo* info) { std::vector subgraph_nodes; // Topologically sorted nodes. std::set added_const_nodes; // Used to prevent double insertion. + + const ClusterProperty& segment_property = segment.property; + const std::set& segment_nodes = segment.nodes; + // The device assignment accumulated from the compatible device assignments // for the nodes in the segment. - DeviceNameUtils::ParsedName segment_device; + const DeviceNameUtils::ParsedName segment_device = + segment_property.DeviceName(); + info->max_batch_size = segment_property.BatchSize().GetOptionalMaxBatchSize(); // Map from src_node_name+port to the unique port numbers of the TRT op, where // the src_node_name is the name of the source node of the input/output @@ -146,18 +156,6 @@ Status GetEngineInfo(const Graph* g, ++it) { const Node* node = *it; if (segment_nodes.count(node) == 0) continue; - - absl::optional new_segment_device = - MergeIfCompatible(segment_device, GetDeviceName(node)); - if (!new_segment_device.has_value()) { - // The segmenter should guarantee that nodes in the same segment have - // compatible device assignments. - return errors::Internal( - "segment nodes have incompatible device assignments: ", - DeviceNameUtils::ParsedNameToString(segment_device), " vs ", - GetDeviceName(node), " to node ", node->name()); - } - segment_device = *new_segment_device; subgraph_nodes.push_back(node); const int node_id = node->id(); @@ -332,7 +330,7 @@ void UpdateToEngineNode(const std::vector& infos, // invocation of CreateTRTNode(). Status CreateTRTNode(const ConversionParams& params, const std::vector& infos, int pos, - int max_batch_size, Graph* graph, + int default_max_batch_size, Graph* graph, std::vector* engine_nodes) { const auto& info = infos.at(pos); std::vector input_shape_protos; @@ -427,6 +425,11 @@ Status CreateTRTNode(const ConversionParams& params, (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration); // Build the engine and get its serialized representation. string segment_string; + + int max_batch_size = info.max_batch_size.has_value() + ? info.max_batch_size.value() + : default_max_batch_size; + if (info.engine_type == EngineInfo::EngineType::TRTStatic) { std::pair device_allocator = GetDeviceAndAllocator(params, info); @@ -443,6 +446,7 @@ Status CreateTRTNode(const ConversionParams& params, cudaSetDevice(cuda_device_id); auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name); + // Create static engines with precision_mode fp32/fp16. TrtUniquePtrType engine; TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( @@ -486,6 +490,7 @@ Status CreateTRTNode(const ConversionParams& params, .Attr("calibration_data", "") .Attr("max_cached_engines_count", info.maximum_cached_engines) .Attr("workspace_size_bytes", info.max_workspace_size_bytes) + .Attr("max_batch_size", max_batch_size) .Attr("precision_mode", prec_string) .Attr("use_calibration", info.use_calibration) .Attr("_use_implicit_batch", params.use_implicit_batch) @@ -738,7 +743,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { segment_options.allow_dynamic_non_batch_dim = AllowDynamicNonBatchDimension(params); - segment::SegmentNodesVector initial_segments; + segment::SegmentVector initial_segments; TrtNodeValidator validator(static_graph_properties, params.precision_mode, params.use_calibration, params.use_implicit_batch); TF_RETURN_IF_ERROR(segment::SegmentGraph( @@ -755,14 +760,11 @@ Status ConvertAfterShapes(const ConversionParams& params) { // Get the EngineInfo for each segment. std::unordered_map node_map; TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); - float total_num_nodes_in_segments = 0.; std::vector engine_segments; engine_segments.reserve(initial_segments.size()); std::vector reverse_topo_order; GetPostOrder(graph, &reverse_topo_order); - size_t total_engine_bytes_size = 0; - std::vector engine_bytes_size; - segment::SegmentNodesVector converted_segments; + segment::SegmentVector converted_segments; converted_segments.reserve(initial_segments.size()); string engine_name_prefix = StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_"); @@ -782,6 +784,9 @@ Status ConvertAfterShapes(const ConversionParams& params) { curr_engine.use_calibration = params.use_calibration; curr_engine.maximum_cached_engines = params.max_cached_engines; curr_engine.allow_build_at_runtime = params.allow_build_at_runtime; + if (!curr_engine.max_batch_size.has_value()) { + curr_engine.max_batch_size = params.max_batch_size; + } status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def, &graph, curr_engine.engine_name); @@ -793,9 +798,6 @@ Status ConvertAfterShapes(const ConversionParams& params) { continue; } - engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong()); - total_engine_bytes_size += engine_bytes_size.back(); - total_num_nodes_in_segments += curr_segment.size(); engine_segments.push_back(std::move(curr_engine)); converted_segments.push_back(std::move(curr_segment)); @@ -834,20 +836,16 @@ Status ConvertAfterShapes(const ConversionParams& params) { engine_nodes.resize(engine_segments.size()); for (int i = 0; i < engine_segments.size(); ++i) { auto& engine = engine_segments.at(i); - // Partition the workspace size by the average of node ratio and segment - // graphdef size - engine.max_workspace_size_bytes = - params.max_workspace_size_bytes * - (engine_bytes_size.at(i) / total_engine_bytes_size + - converted_segments.at(i).size() / total_num_nodes_in_segments) / - 2.0; + // TODO(b/170762693): implement the heuristic to calculate + // max_workspace_size_bytes. + engine.max_workspace_size_bytes = params.max_workspace_size_bytes; VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to " << engine.engine_name; auto status = CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph, &engine_nodes); string msg = StrCat("segment ", i, " consisting of ", - converted_segments.at(i).size(), " nodes by ", + converted_segments.at(i).nodes.size(), " nodes by ", engine.engine_name); if (status.ok()) { LOG(INFO) << "Replaced " << msg << "."; @@ -859,7 +857,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { } if (VLOG_IS_ON(1)) { msg = "Segment consists of nodes: "; - for (const Node* node : converted_segments.at(i)) { + for (const Node* node : converted_segments.at(i).nodes) { StrAppend(&msg, node->name(), ", "); } VLOG(1) << msg; @@ -868,7 +866,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { // If status is ok, we successfully added the node to the graph and can // remove segment ops. Otherwise graph is not modified. if (status.ok()) { - for (const Node* node : converted_segments.at(i)) { + for (const Node* node : converted_segments.at(i).nodes) { graph.RemoveNode(const_cast(node)); } } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index d09485c35c7..99d7730bfe5 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1599,12 +1599,19 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, return Status::OK(); } -Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, - const nvinfer1::Dims& dims, - const bool validation_only, - nvinfer1::ITensor** tensor, - const NodeDef& node_def, - absl::optional op_instance) { +// Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims' +// (which doesn't contain the batch dimension). +// +// If validation_only is true, it doesn't do the conversion but only do some +// minimum validation for the eligibility of the conversion, and *tensor will +// be set to nullptr. +Status PrepareTensorForShape(Converter* converter, + const TRT_TensorOrWeights& input, + const nvinfer1::Dims& dims, + const bool validation_only, + nvinfer1::ITensor** tensor, + const NodeDef& node_def, + absl::optional op_instance) { const nvinfer1::Dims input_dims = input.GetTrtDims(); // If one of input_dims and dims doesn't have static shape, it means some of // the dims are unknown or need to be inferred. And we don't do further checks @@ -1628,29 +1635,32 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, return Status::OK(); } + TFTRT_RETURN_ERROR_IF_NULLPTR(converter, "converter is nullptr"); if (input.is_tensor()) { if (DimsEqual(input_dims, dims)) { *tensor = input.tensor(); } else { nvinfer1::IShuffleLayer* layer = - this->network()->addShuffle(*input.tensor()); + converter->network()->addShuffle(*input.tensor()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); SetLayerName(layer, node_def, "shuffle", op_instance); layer->setReshapeDimensions(dims); - MarkQuantizationRangesAsInferrable(input.tensor(), layer->getOutput(0)); + converter->MarkQuantizationRangesAsInferrable(input.tensor(), + layer->getOutput(0)); *tensor = layer->getOutput(0); } } else { - *tensor = CreateConstantLayer(input.weights(), dims); + *tensor = converter->CreateConstantLayer(input.weights(), dims); TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, "TF-TRT Internal Reshape"); - if (precision_mode() == TrtPrecisionMode::INT8 && !use_calibration()) { + if (converter->precision_mode() == TrtPrecisionMode::INT8 && + !converter->use_calibration()) { // If we are in int8 mode and not calibrating, we need to explicitly set a // quantization range for the output tensor of the IConstantLayer. Here we // set the range to [min(weights), max(weights)]. float min_range = 0.0f; float max_range = 0.0f; TF_RETURN_IF_ERROR( - GetWeightRange(input.weights(), &min_range, &max_range)); + converter->GetWeightRange(input.weights(), &min_range, &max_range)); // Avoid setting range to 0 because TRT will throw an error. If the // weights are zero then the range doesn't matter: using 127.0f should // ensure the quantized weight will be exactly zero. @@ -1658,7 +1668,7 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, min_range = -127.0f; max_range = 127.0f; } - ProvideQuantizationRange(*tensor, min_range, max_range); + converter->ProvideQuantizationRange(*tensor, min_range, max_range); } } return Status::OK(); @@ -2520,9 +2530,9 @@ Status ConvertReshape(OpConverterParams* params) { // Perform the conversion. nvinfer1::ITensor* output_tensor = nullptr; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input_tensor, output_nonbatch_dims, params->validation_only, - &output_tensor, params->node_def)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, input_tensor, output_nonbatch_dims, + params->validation_only, &output_tensor, params->node_def)); if (params->validation_only) return Status::OK(); // Record the conversion result. @@ -2564,9 +2574,9 @@ Status ConvertExpandDims(OpConverterParams* params) { // Reshape tensor. nvinfer1::Dims new_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims)); - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input_tensor, new_dims, /*validation_only=*/false, &output_tensor, - params->node_def)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, input_tensor, new_dims, /*validation_only=*/false, + &output_tensor, params->node_def)); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -2672,9 +2682,9 @@ Status Converter::SqueezeTensor(nvinfer1::ITensor* input, nvinfer1::Dims new_dims; VLOG(2) << "input_dims" << input_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(*input_dims, &new_dims)); - TF_RETURN_IF_ERROR(PrepareTensorForShape(TRT_TensorOrWeights(input), new_dims, - /*validation_only=*/false, output, - params->node_def)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, TRT_TensorOrWeights(input), new_dims, + /*validation_only=*/false, output, params->node_def)); return Status::OK(); } @@ -2786,9 +2796,9 @@ Status ConvertStridedSliceHelper( nvinfer1::ITensor* tensor = layer->getOutput(0); // Reshape for shrink_axis. if (final_shape) { - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(tensor), *final_shape, /*validation_only=*/false, - &tensor, node_def, op_instance)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, TRT_TensorOrWeights(tensor), *final_shape, + /*validation_only=*/false, &tensor, node_def, op_instance)); } params->outputs->push_back(TRT_TensorOrWeights(tensor)); return Status::OK(); @@ -2890,9 +2900,9 @@ Status ConvertStridedSliceHelper( // Start conversion. nvinfer1::ITensor* tensor = input.tensor(); if (need_reshape) { - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input, reshape_dims, /*validation_only=*/false, &tensor, node_def, - op_instance)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, input, reshape_dims, /*validation_only=*/false, + &tensor, node_def, op_instance)); } if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( @@ -2914,9 +2924,9 @@ Status ConvertStridedSliceHelper( } // Reshape for shrink_axis. if (final_shape) { - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(tensor), *final_shape, /*validation_only=*/false, - &tensor, node_def, op_instance)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, TRT_TensorOrWeights(tensor), *final_shape, + /*validation_only=*/false, &tensor, node_def, op_instance)); } else if (need_reshape) { // Restore reshape. // Calculate output dimensions @@ -2937,9 +2947,9 @@ Status ConvertStridedSliceHelper( nvinfer1::Dims new_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, /*ignore_first_dim=*/true)); - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(tensor), new_dims, /*validation_only=*/false, - &tensor, node_def, op_instance)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, TRT_TensorOrWeights(tensor), new_dims, + /*validation_only=*/false, &tensor, node_def, op_instance)); } params->outputs->push_back(TRT_TensorOrWeights(tensor)); @@ -4121,17 +4131,18 @@ Status ConvertBiasAdd(OpConverterParams* params) { // Convert input to a TRT tensor nvinfer1::ITensor* input_tensor{nullptr}; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), input_shape, params->validation_only, &input_tensor, - node_def, - /*op_instance=*/0)); + TF_RETURN_IF_ERROR(PrepareTensorForShape(params->converter, inputs.at(0), + input_shape, params->validation_only, + &input_tensor, node_def, + /*op_instance=*/0)); // Finally, reshape bias. Since the bias is usually a constant, this will // normally happen at conversion-time. nvinfer1::ITensor* bias_tensor{nullptr}; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(1), bias_shape, params->validation_only, &bias_tensor, node_def, - /*op_instance=*/1)); + TF_RETURN_IF_ERROR(PrepareTensorForShape(params->converter, inputs.at(1), + bias_shape, params->validation_only, + &bias_tensor, node_def, + /*op_instance=*/1)); VLOG(2) << "Bias shape adjusted to " << DebugString(bias_shape); if (params->validation_only) return Status::OK(); @@ -4368,12 +4379,12 @@ Status ConvertBinary(OpConverterParams* params) { nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_r = nullptr; // This will also convert constants to tensors, and set quantization ranges. - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - operand_l, broadcasted_dims_l, params->validation_only, &tensor_l, - node_def, /*op_instance=*/0)); - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - operand_r, broadcasted_dims_r, params->validation_only, &tensor_r, - node_def, /*op_instance=*/1)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, operand_l, broadcasted_dims_l, params->validation_only, + &tensor_l, node_def, /*op_instance=*/0)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, operand_r, broadcasted_dims_r, params->validation_only, + &tensor_r, node_def, /*op_instance=*/1)); if (params->validation_only) return Status::OK(); // Add ElementWise layer. @@ -4674,9 +4685,9 @@ Status ConvertPack(OpConverterParams* params) { input_index)); } } else { - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input, expanded_dims, params->validation_only, &expanded_tensor, - node_def, input_index)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, input, expanded_dims, params->validation_only, + &expanded_tensor, node_def, input_index)); } if (!params->validation_only) { expanded_tensors.push_back(expanded_tensor); @@ -5247,8 +5258,9 @@ Status ConvertGather(OpConverterParams* params) { trt_gather_output_dims.d[trt_axis] = 1; ++trt_gather_output_dims.nbDims; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(output_tensor), trt_gather_output_dims, + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, TRT_TensorOrWeights(output_tensor), + trt_gather_output_dims, /*validation_only=*/false, &output_tensor, node_def)); } @@ -5265,9 +5277,9 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params, while (input_dim.nbDims < 3) { input_dim.d[input_dim.nbDims++] = 1; } - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(tensor_a), input_dim, /*validation_only=*/false, - &tensor_a, node_def, /*op_instance=*/0)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, TRT_TensorOrWeights(tensor_a), input_dim, + /*validation_only=*/false, &tensor_a, node_def, /*op_instance=*/0)); // FC layer will transpose weights, so we need to pre-transpose. TRT_ShapedWeights weights(weights_b.TrtDType()); @@ -5290,9 +5302,9 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params, // Reshape output to 1D - this will be a no-op unless using int8 precision. auto output_dim = output_tensor->getDimensions(); output_dim.nbDims = 1; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(output_tensor), output_dim, /*validation_only=*/false, - &output_tensor, node_def, /*op_instance=*/1)); + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, TRT_TensorOrWeights(output_tensor), output_dim, + /*validation_only=*/false, &output_tensor, node_def, /*op_instance=*/1)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5344,10 +5356,9 @@ Status ConvertMatMulHelper(OpConverterParams* params, const auto get_matrix_op = [](nvinfer1::ITensor* in, bool transpose) -> nvinfer1::MatrixOperation { - return (in->getDimensions().nbDims < 2) - ? nvinfer1::MatrixOperation::kVECTOR - : (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE - : nvinfer1::MatrixOperation::kNONE; + return (in->getDimensions().nbDims < 2) ? nvinfer1::MatrixOperation::kVECTOR + : (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE + : nvinfer1::MatrixOperation::kNONE; }; // If the MatMul operand is a constant, applies transposes at conversion-time @@ -5466,12 +5477,12 @@ Status ConvertBatchMatMul(OpConverterParams* params) { params->use_implicit_batch, &broadcasted_dims_l, &broadcasted_dims_r)); nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_r = nullptr; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l, - node_def)); - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r, - node_def)); + TF_RETURN_IF_ERROR( + PrepareTensorForShape(params->converter, inputs.at(0), broadcasted_dims_l, + params->validation_only, &tensor_l, node_def)); + TF_RETURN_IF_ERROR( + PrepareTensorForShape(params->converter, inputs.at(1), broadcasted_dims_r, + params->validation_only, &tensor_r, node_def)); if (params->validation_only) return Status::OK(); return ConvertMatMulHelper(params, TRT_TensorOrWeights(tensor_l), @@ -5552,8 +5563,8 @@ Status ConvertArgMinMax(OpConverterParams* params) { nvinfer1::Dims new_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(size, &new_dims)); nvinfer1::ITensor* output_tensor = nullptr; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(output_indices_tensor), new_dims, + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, TRT_TensorOrWeights(output_indices_tensor), new_dims, /*validation_only=*/false, &output_tensor, node_def)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -5713,12 +5724,12 @@ Status ConvertSquaredDifference(OpConverterParams* params) { params->use_implicit_batch, &broadcasted_dims_l, &broadcasted_dims_r)); nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_r = nullptr; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l, - node_def)); - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r, - node_def)); + TF_RETURN_IF_ERROR( + PrepareTensorForShape(params->converter, inputs.at(0), broadcasted_dims_l, + params->validation_only, &tensor_l, node_def)); + TF_RETURN_IF_ERROR( + PrepareTensorForShape(params->converter, inputs.at(1), broadcasted_dims_r, + params->validation_only, &tensor_r, node_def)); if (params->validation_only) return Status::OK(); // Subtract x - y. @@ -5894,8 +5905,8 @@ Status ConvertCombinedNMS(OpConverterParams* params) { DebugString(*in_tensor)); } --dims.nbDims; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(in_tensor), dims, + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, TRT_TensorOrWeights(in_tensor), dims, /*validation_only=*/false, out_tensor, node_def, output_index)); return Status::OK(); }; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 4a84793e254..35593143332 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -92,6 +92,8 @@ struct EngineInfo { EngineInfo() : engine_type(EngineType::TRTStatic), max_workspace_size_bytes(0), + max_batch_size(absl::nullopt), + maximum_cached_engines(0), precision_mode(TrtPrecisionMode::FP32), use_calibration(true), allow_build_at_runtime(true) {} @@ -108,6 +110,7 @@ struct EngineInfo { enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; EngineType engine_type; int64 max_workspace_size_bytes; + absl::optional max_batch_size; int maximum_cached_engines; TrtPrecisionMode precision_mode; bool use_calibration; @@ -526,19 +529,6 @@ class Converter { const NodeDef& node_def, absl::string_view sub_op_name = ""); - // Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims' - // (which doesn't contain the batch dimension). - // - // If validation_only is true, it doesn't do the conversion but only do some - // minimum validation for the eligibility of the conversion, and *tensor will - // be set to nullptr. - Status PrepareTensorForShape(const TRT_TensorOrWeights& input, - const nvinfer1::Dims& dims, - const bool validation_only, - nvinfer1::ITensor** tensor, - const NodeDef& node_def, - absl::optional op_instance = absl::nullopt); - // Reshapes a dynamic shape tensor by removing or adding dimensions of size 1, // and/or permuting the dimensions. The new shape is derived from the shape of // the input tensor according to the slices and size_for_added_dims arguments. @@ -603,6 +593,10 @@ class Converter { nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights, const nvinfer1::Dims& dims); + // Gets the min and max value in a TRT_ShapedWeights + Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, + float* out_max) const; + private: Converter(TrtPrecisionMode precision_mode, bool use_calibration, nvinfer1::ILogger* trt_logger, const bool use_implicit_batch); @@ -627,10 +621,6 @@ class Converter { void PropagateQuantizationRanges(); - // Gets the min and max value in a TRT_ShapedWeights - Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, - float* out_max) const; - // Registered op converters by op type. std::unordered_map op_registry_; @@ -684,6 +674,21 @@ class Converter { friend class OpConverterTest; }; +// Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims' +// (which doesn't contain the batch dimension). +// +// If validation_only is true, it doesn't do the conversion but only do some +// minimum validation for the eligibility of the conversion, and *tensor will +// be set to nullptr. +// If validation_only is false converter must not be nullptr. +Status PrepareTensorForShape(Converter* converter, + const TRT_TensorOrWeights& input, + const nvinfer1::Dims& dims, + const bool validation_only, + nvinfer1::ITensor** tensor, + const NodeDef& node_def, + absl::optional op_instance = absl::nullopt); + // Return OK if the broadcast scheme is supported and compute the shapes after // broadcasting. check_feasibility can be set to false in cases where dimensions // do not need to match exactly (as in the case of BatchMatMulV2). diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 86e6f0dd345..a33d5c28cb2 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -949,9 +949,9 @@ void TestPrepareTensorForShape( NodeDef dummy_node_def = MakeNodeDef("dummy_op", "DummyOp", {}); for (bool validation_only : {false, true}) { - const Status status = converter->PrepareTensorForShape( - input, GetTestDims(reshape_dims), validation_only, &output_tensor, - dummy_node_def); + const Status status = + PrepareTensorForShape(converter, input, GetTestDims(reshape_dims), + validation_only, &output_tensor, dummy_node_def); if (expected_code == error::OK) { TF_EXPECT_OK(status); if (validation_only) { @@ -4275,72 +4275,72 @@ TEST_P(OpConverterTest1, ConvertConv2D) { // Ok. std::vector ok_params = { - // Basic - TestParams{/*input_dims=*/{1, 1, 2, 3}, - /*input=*/{0, 1, 2, 3, 3, 4}, - /*filter_dims=*/{1, 2, 1, 1}, - /*filter=*/{-1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*padding=*/"VALID", - /*data_format=*/"NCHW", - /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 1, 2, 2}, - /*expected_output=*/{1, 1, 0, 1}}, - // SAME padding (Asymmetric) - TestParams{/*input_dims=*/{1, 1, 2, 3}, - /*input=*/{0, 1, 2, 3, 3, 4}, - /*filter_dims=*/{1, 2, 1, 1}, - /*filter=*/{-1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*padding=*/"SAME", - /*data_format=*/"NCHW", - /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 1, 2, 3}, - /*expected_output=*/{1, 1, -2, 0, 1, -4}}, - // SAME padding (Symmetric) - TestParams{/*input_dims=*/{1, 1, 2, 3}, - /*input=*/{0, 1, 2, 3, 3, 4}, - /*filter_dims=*/{1, 3, 1, 1}, - /*filter=*/{-1, 0, 1}, - /*strides=*/{1, 1, 1, 1}, - /*padding=*/"SAME", - /*data_format=*/"NCHW", - /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 1, 2, 3}, - /*expected_output=*/{1, 2, -1, 3, 1, -3}}, - // NHWC - TestParams{/*input_dims=*/{1, 2, 3, 1}, - /*input=*/{0, 1, 2, 3, 3, 4}, - /*filter_dims=*/{1, 2, 1, 1}, - /*filter=*/{-1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*padding=*/"VALID", - /*data_format=*/"NHWC", - /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 2, 1}, - /*expected_output=*/{1, 1, 0, 1}}, - // Dilated - TestParams{/*input_dims=*/{1, 1, 2, 3}, - /*input=*/{0, 1, 2, 3, 3, 4}, - /*filter_dims=*/{1, 2, 1, 1}, - /*filter=*/{-1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*padding=*/"VALID", - /*data_format=*/"NCHW", - /*dilations=*/{1, 1, 1, 2}, - /*expected_output_dims=*/{1, 1, 2, 1}, - /*expected_output=*/{2, 1}}, - // Strided - TestParams{/*input_dims=*/{1, 1, 2, 4}, - /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, - /*filter_dims=*/{1, 2, 1, 1}, - /*filter=*/{-1, 1}, - /*strides=*/{1, 1, 1, 2}, - /*padding=*/"VALID", - /*data_format=*/"NCHW", - /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 1, 2, 2}, - /*expected_output=*/{1, 0, 1, 3}}, + // Basic + TestParams{/*input_dims=*/{1, 1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 1, 2, 2}, + /*expected_output=*/{1, 1, 0, 1}}, + // SAME padding (Asymmetric) + TestParams{/*input_dims=*/{1, 1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 1, 2, 3}, + /*expected_output=*/{1, 1, -2, 0, 1, -4}}, + // SAME padding (Symmetric) + TestParams{/*input_dims=*/{1, 1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 3, 1, 1}, + /*filter=*/{-1, 0, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 1, 2, 3}, + /*expected_output=*/{1, 2, -1, 3, 1, -3}}, + // NHWC + TestParams{/*input_dims=*/{1, 2, 3, 1}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NHWC", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 2, 2, 1}, + /*expected_output=*/{1, 1, 0, 1}}, + // Dilated + TestParams{/*input_dims=*/{1, 1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 2}, + /*expected_output_dims=*/{1, 1, 2, 1}, + /*expected_output=*/{2, 1}}, + // Strided + TestParams{/*input_dims=*/{1, 1, 2, 4}, + /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 2}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 1, 2, 2}, + /*expected_output=*/{1, 0, 1, 3}}, }; for (int i = 0; i < ok_params.size(); i++) { @@ -6812,20 +6812,20 @@ template void TestConvertResize(OpConverterTest* test) { typedef typename EnumToDataType::Type CType; - std::vector> params { - { - /*input_dims=*/{1, 2, 1}, // H, W, C - /*output_resize_dims=*/{2, 3}, // H_out, W_out - /*input_values=*/CastTestVector({2.0f, -1.0f}), - /*align_corners=*/false, - /*expected_output_dims=*/{2, 3, 1}, // H, W, C - /*expected_nearest_output_values=*/ - CastTestVector({2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f}), - /*expected_bilinear_output_values=*/ - CastTestVector({2.0f, 0.f, -1.0f, 2.0f, 0.f, -1.0f}), - }, - { - /*input_dims=*/{1, 2, 1}, // H, W, C + std::vector> params{ + { + /*input_dims=*/{1, 2, 1}, // H, W, C + /*output_resize_dims=*/{2, 3}, // H_out, W_out + /*input_values=*/CastTestVector({2.0f, -1.0f}), + /*align_corners=*/false, + /*expected_output_dims=*/{2, 3, 1}, // H, W, C + /*expected_nearest_output_values=*/ + CastTestVector({2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f}), + /*expected_bilinear_output_values=*/ + CastTestVector({2.0f, 0.f, -1.0f, 2.0f, 0.f, -1.0f}), + }, + { + /*input_dims=*/{1, 2, 1}, // H, W, C /*output_resize_dims=*/{2, 3}, // H_out, W_out /*input_values=*/CastTestVector({2.0f, -1.0f}), /*align_corners=*/true, @@ -6834,8 +6834,7 @@ void TestConvertResize(OpConverterTest* test) { CastTestVector({2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f}), /*expected_bilinear_output_values=*/ CastTestVector({2.0f, 0.5f, -1.0f, 2.0f, 0.5f, -1.0f}), - } - }; + }}; // This use case is not supported as of TRT version 7.1 #if IS_TRT_VERSION_GE(7, 1, 0, 0) diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index 2527fe9b910..a2a41f5a03c 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -37,6 +37,7 @@ REGISTER_OP("TRTEngineOp") .Attr("OutT: list({int8,float16,float32,int32})") .Attr("input_shapes: list(shape) = []") .Attr("max_cached_engines_count: int = 1") + .Attr("max_batch_size: int = 1") .Attr("workspace_size_bytes: int") .Attr("precision_mode: {'FP32', 'FP16', 'INT8'}") .Attr("calibration_data: string = ''") diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 35f966c816f..02ba31fecd2 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" -#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" @@ -242,12 +241,6 @@ struct SimpleEdgePtrCompare { } }; -struct NodePtrCompare { - bool operator()(const Node* lhs, const Node* rhs) const { - return lhs->name() < rhs->name(); - } -}; - // Copied from TF ReverseDFS, which only works for Graph. void StableDFS(const SimpleGraph& g, bool reverse, const std::vector& start, @@ -646,10 +639,18 @@ ClusterBatchSize GetClusterBatchSizeForNode( return cluster_batch_size; } + const NodeDef& node_def = node->def(); + if (node_def.attr().count(kTftrtOpMaxBatchSizeAttr)) { + cluster_batch_size.SetMaxBatchSize( + node_def.attr().at(kTftrtOpMaxBatchSizeAttr).i()); + } + + // As shape inference cannot provide any useful information about the batch + // size, we keep it as missing. if (!graph_properties || !graph_properties->HasInputProperties(node->name())) { VLOG(3) << "doesn't have input property"; - return cluster_batch_size.SetBatchSize(-1); + return cluster_batch_size; } const std::vector& input_properties = @@ -658,8 +659,8 @@ ClusterBatchSize GetClusterBatchSizeForNode( FindLeadingShape(GetInputsToDeterminateBatchSize(node, input_properties)); DCHECK(optional_leading_shape.has_value()); const TensorShapeProto* leading_shape = optional_leading_shape.value(); - DCHECK(!leading_shape->unknown_rank() && leading_shape->dim_size() >= 2); + VLOG(3) << "set batch size as " << leading_shape->dim(0).size(); return cluster_batch_size.SetBatchSize(leading_shape->dim(0).size()); } @@ -676,21 +677,6 @@ void AddSegmentForNode(const grappler::GraphProperties* graph_properties, segments->emplace_back(node, std::move(property)); } -bool OpBatchSizeExceedMaximumBatchSize( - const grappler::GraphProperties* graph_properties, const Node* node, - bool use_implicit_batch, absl::optional maximum_batch_size) { - ClusterBatchSize cluster_batch_size = - GetClusterBatchSizeForNode(graph_properties, node, use_implicit_batch); - if (cluster_batch_size.HasStaticBatchSize() && - maximum_batch_size.has_value() && - cluster_batch_size.GetStaticBatchSize() > maximum_batch_size.value()) { - VLOG(2) << "OP batch size " << cluster_batch_size.GetStaticBatchSize() - << " max_batch_size " << maximum_batch_size.value(); - return true; - } - return false; -} - } // namespace Status SegmentGraph(const Graph* tf_graph, @@ -698,8 +684,7 @@ Status SegmentGraph(const Graph* tf_graph, const std::function& candidate_fn, const std::function& input_candidate_fn, const std::function& output_candidate_fn, - const SegmentOptions& options, - SegmentNodesVector* segments) { + const SegmentOptions& options, SegmentVector* segments) { if (!options.use_implicit_batch && !options.allow_dynamic_non_batch_dim) { return errors::Internal( "Explicit batch mode should allow dynamic non-batch dimensions"); @@ -787,14 +772,6 @@ Status SegmentGraph(const Graph* tf_graph, << "(Op type: " << node->tf_node()->type_string() << "), " << "(Op name: " << node->name() << ")"; exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST"); - } else if (OpBatchSizeExceedMaximumBatchSize( - graph_properties, node->tf_node(), - options.use_implicit_batch, options.maximum_batch_size)) { - LOG_WARNING_WITH_PREFIX - << "Implicit batch mode requires OP batch size not larger than " - << "the converter maximum batch size: " - << "(Op name: " << node->name() << ")"; - exclude_node("OP batch size too large"); } else { VLOG(2) << "Accepted as a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " @@ -943,18 +920,21 @@ Status SegmentGraph(const Graph* tf_graph, // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. - std::map> sg_map; + std::map sg_map; for (auto& u : node_segments) { if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) { - sg_map[u.ParentValue()->name()].insert(u.Value()->tf_node()); + sg_map[u.ParentValue()->name()].nodes.insert(u.Value()->tf_node()); + } + if ((u.Value() != nullptr) && (u.ParentValue() == u.Value())) { + sg_map[u.Value()->name()].property = u.Property(); } } // --------------------------------- Step 2 --------------------------------- // Remove ineligible input/output nodes. for (auto& itr : sg_map) { - std::set& segment_nodes = itr.second; + std::set& segment_nodes = itr.second.nodes; VLOG(1) << "Segment original size: " << segment_nodes.size(); while (true) { std::deque in_nodes_que, out_nodes_que; @@ -1042,7 +1022,8 @@ Status SegmentGraph(const Graph* tf_graph, for (const auto& itr : sg_map) { const string& segment_root = itr.first; // Return format does not require set comparator. - std::set segment_nodes(itr.second.begin(), itr.second.end()); + std::set segment_nodes( + itr.second.nodes.begin(), itr.second.nodes.end()); if (VLOG_IS_ON(1) && !segment_nodes.empty()) { string s; for (auto node : segment_nodes) { @@ -1066,8 +1047,7 @@ Status SegmentGraph(const Graph* tf_graph, << num_effective_nodes << " effective nodes, dropping"; continue; } - - segments->emplace_back(segment_nodes); + segments->emplace_back(itr.second.property, segment_nodes); } return Status::OK(); diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h index bab6e089fa4..ad41d5eb40f 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.h +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -32,10 +33,10 @@ namespace tensorflow { namespace tensorrt { namespace segment { -// Vector of segments, each entry contains a set of node pointers. -using SegmentNodesVector = std::vector>; +constexpr char kTftrtOpMaxBatchSizeAttr[] = "_tftrt_op_max_batch_size"; struct SegmentOptions { + // This struct holds per graph segmenting parameters. // Segment must contain at least this many nodes. int minimum_segment_size = 2; bool use_implicit_batch = true; @@ -45,9 +46,28 @@ struct SegmentOptions { // When use_implicit_batch is false or when we are building dynamic engines, // we allow dynamic non-batch dimensions. bool allow_dynamic_non_batch_dim = false; + // The name of the device to put the segment on. std::set exclude_node_list; }; +struct NodePtrCompare { + bool operator()(const Node* lhs, const Node* rhs) const { + return lhs->name() < rhs->name(); + } +}; + +struct Segment { + Segment() {} + Segment(const ClusterProperty& property, + const std::set& nodes) + : property(property), nodes(nodes) {} + ClusterProperty property; + std::set nodes; +}; + +// Vector of segments, each entry contains a set of node pointers. +using SegmentVector = std::vector; + // Get the subgraphs of a graph that can be handled by TensorRT. // // @param tf_graph Graph of the network. @@ -63,8 +83,7 @@ Status SegmentGraph(const Graph* tf_graph, const std::function& candidate_fn, const std::function& input_candidate_fn, const std::function& output_candidate_fn, - const SegmentOptions& options, - SegmentNodesVector* segments); + const SegmentOptions& options, SegmentVector* segments); } // namespace segment } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index ee406c9743f..12f3e7a5742 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -65,7 +65,7 @@ class SegmentTest : public ::testing::Test { const std::set& input_candidates, const std::set& output_candidates, const std::vector>& expected_segments) { - SegmentNodesVector segments; + SegmentVector segments; TF_EXPECT_OK(SegmentGraph(graph, graph_properties, MakeCandidateFn(candidates), MakeInputEdgeCandidateFn(input_candidates), @@ -82,12 +82,12 @@ class SegmentTest : public ::testing::Test { expected_segments); } - void ValidateSegment(const SegmentNodesVector& segments, + void ValidateSegment(const SegmentVector& segments, const std::vector>& expected_segments) { EXPECT_EQ(expected_segments.size(), segments.size()); for (int i = 0; i < segments.size(); ++i) { std::set segment_node_names; - for (const Node* node : segments[i]) { + for (const Node* node : segments[i].nodes) { segment_node_names.insert(node->name()); } const auto& expected = expected_segments[i]; @@ -490,9 +490,10 @@ TEST_F(SegmentTest, TwoChainsDiffBatchSizes) { RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, /*expected_segments=*/{{"output-0", "const-scalar"}}); + // Converter will create engines based on the static batch size EnableImplicitBatchModeForStaticEngine(1); RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, - /*expected_segments=*/{}); + /*expected_segments=*/{{"output-0", "const-scalar"}}); } TEST_F(SegmentTest, SameRankImplicitBroadcastingStaticBatchSize) { diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.cc b/tensorflow/compiler/tf2tensorrt/segment/union_find.cc index 289a2734183..29882ed6e60 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/union_find.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.cc @@ -17,7 +17,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" -#include "tensorflow/core/lib/core/errors.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -29,9 +28,6 @@ namespace { template inline bool CheckIfCompatible(const absl::optional& a, const absl::optional& b) { - if (!a.has_value() && !b.has_value()) { - return true; - } if (a.has_value() && b.has_value()) { return *a == *b; } @@ -52,57 +48,79 @@ template inline absl::optional MergeCompatible(const absl::optional& a, const absl::optional& b) { DCHECK(CheckIfCompatible(a, b)); - return b; + return a.has_value() ? a : b; } } // namespace ClusterBatchSize::ClusterBatchSize() - : has_dynamic_batch_size_(false), static_batch_size_(absl::nullopt) {} + : batch_size_(absl::nullopt), max_batch_size_(absl::nullopt) {} bool ClusterBatchSize::operator==(const ClusterBatchSize& other) { - return has_dynamic_batch_size_ == other.has_dynamic_batch_size_ && - static_batch_size_ == other.static_batch_size_; + return batch_size_ == other.batch_size_ && + max_batch_size_ == other.max_batch_size_; } -int ClusterBatchSize::GetStaticBatchSize() const { - DCHECK(HasStaticBatchSize()); - return static_batch_size_.value(); -} - -// Sets the batch size assuming that the object doesn't have a batch size yet: -// a non-negative input value representing a static batch size. -// a negative input value representing a dynamic batch size. ClusterBatchSize& ClusterBatchSize::SetBatchSize(int batch_size) { - if (batch_size < 0) { - has_dynamic_batch_size_ = true; - return *this; - } - static_batch_size_ = MergeCompatible(static_batch_size_, batch_size); + SetBatchSize(static_cast>(batch_size)); return *this; } +ClusterBatchSize& ClusterBatchSize::SetBatchSize( + const absl::optional& batch_size) { + batch_size_ = MergeCompatible(batch_size_, batch_size); + if (batch_size_.has_value() && batch_size_.value() >= 0) { + SetMaxBatchSize(batch_size_); + } + return *this; +} + +bool ClusterBatchSize::HasBatchSize() const { return batch_size_.has_value(); } + +int ClusterBatchSize::GetBatchSize() const { + DCHECK(HasBatchSize()); + return batch_size_.value(); +} + +ClusterBatchSize& ClusterBatchSize::SetMaxBatchSize(int max_batch_size) { + SetBatchSize(static_cast>(max_batch_size)); + return *this; +} + +ClusterBatchSize& ClusterBatchSize::SetMaxBatchSize( + const absl::optional& max_batch_size) { + max_batch_size_ = MergeCompatible(max_batch_size_, max_batch_size); + return *this; +} + +absl::optional ClusterBatchSize::GetOptionalMaxBatchSize() const { + return max_batch_size_; +} + bool ClusterBatchSize::MergeIfCompatible(const ClusterBatchSize& other) { - if (!CheckIfCompatible(static_batch_size_, other.static_batch_size_)) { + if (!CheckIfCompatible(batch_size_, other.batch_size_) || + !CheckIfCompatible(max_batch_size_, other.max_batch_size_)) { return false; } - if (other.HasStaticBatchSize()) { - static_batch_size_ = other.GetStaticBatchSize(); - } - if (other.HasDynamicBatchSize()) { - has_dynamic_batch_size_ = true; - } + + SetBatchSize(other.batch_size_); + SetMaxBatchSize(other.max_batch_size_); return true; } string ClusterBatchSize::ToString() const { string s; - absl::StrAppendFormat(&s, "batch_size=(%d,%d", HasDynamicBatchSize(), - HasStaticBatchSize()); - if (HasStaticBatchSize()) { - absl::StrAppendFormat(&s, ",%d", GetStaticBatchSize()); - } - absl::StrAppend(&s, ")"); + const auto append_optional_num = [&](const absl::optional& num) { + if (num.has_value()) { + absl::StrAppendFormat(&s, "%d", num.value()); + } else { + absl::StrAppendFormat(&s, "?"); + } + }; + absl::StrAppendFormat(&s, "batch_size="); + append_optional_num(batch_size_); + absl::StrAppendFormat(&s, ", max_batch_size="); + append_optional_num(max_batch_size_); return s; } diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h index 74a20aa4a24..9a2f1e8dd5b 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ #include "absl/types/optional.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/device_name_utils.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -28,30 +29,45 @@ namespace segment { // ClusterBatchSize is a data structure to record the batch size we have seen // for a cluster during segmentation. // -// When constructing clusters for implicit batch mode, we support the -// with both dynamic batch size and static batch size. We restrict nodes inside -// a cluster to either have dynamic batch size or have the same value for static -// batch size. For this reason, we use a field has_dynamic_batch_size_ to keep -// track of whether the cluster has any node with dynamic batch size. We use -// field static_batch_size_ to keep track of whether the cluster has any node -// with static batch size and what the value of the static batch size, if any. -// Examples: +// With the help of shape inference, all the dynamic batch sizes are converted +// to a negative integer number. +// If the number is -1, then nothing is known about the dynamic batch size. +// Ideally, we should not put nodes with -1 batch size into the same cluster, +// as they will likely have different batch sizes at runtime. However, we +// currently treat -1 as an equivalent class for simple implementation. We may +// need to revise this if it causes performance issues. +// If the number is strictly less than -1, then it represents a equivalent +// class. It is infered that all the nodes with the same equivalent class +// (strictly less than -1) shall have the same batch size at runtime. +// +// When constructing clusters for implicit batch mode, we support both +// dynamic batch sizes and static batch sizes. As all the nodes inside the same +// cluster shall have the same batch size at runtime, we restrict nodes inside a +// cluster to either have the same dynamic batch size equivalent class or the +// same static batch size value. +// +// Besides, all the nodes with an annotated max batch size inside the same +// cluster shall have the same annotated max batch size. (It is allowed if +// part or all the nodes inside the cluster doesn't have annotated max batch +// size). Static batch sizes are treated as max batch size annotations. The +// converter max batch size is used for an OP with a dynamic batch size and no +// annotated max batch size. +// // cluster: a = a1[1,3] + a1[1,3] -// ClusterBatchSize: has_dynamic_batch_size_ = false -// static_batch_size_ = {has value, 1} +// ClusterBatchSize: batch_size_ = 1 +// max_batch_size_ = 1 // // cluster: b = b1[-1,3] + b2[-1, 3] -// ClusterBatchSize: has_dynamic_batch_size_ = true -// static_batch_size_ = {has no value} +// ClusterBatchSize: batch_size_ = -1 +// max_batch_size_ = null // -// cluster: a = a1[1,3] + a1[1,3]; b = b1[-1,3] + b2[-1, 3] -// ClusterBatchSize: has_dynamic_batch_size_ = true -// static_batch_size_ = {has value, 1} +// cluster: c = c1[-2,3] + c2[-2, 3](max_batch_size=100) +// ClusterBatchSize: batch_size_ = -2 +// max_batch_size_ = 100 // // When constructing cluster for explicit batch mode, all ClusterBatchSize is // irrelevant. // -// class ClusterBatchSize { public: @@ -61,29 +77,41 @@ class ClusterBatchSize { bool operator!=(const ClusterBatchSize& other) { return !(*this == other); } // Sets the batch size assuming that the object doesn't have a batch size yet: - // a non-negative input value representing a static batch size. - // a negative input value representing a dynamic batch size. + // A non-negative input representing a static batch size value. + // A negative input representing a dynamic batch size equivalent class. ClusterBatchSize& SetBatchSize(int batch_size); - bool HasStaticBatchSize() const { return static_batch_size_.has_value(); } - int GetStaticBatchSize() const; + bool HasBatchSize() const; + int GetBatchSize() const; + // Sets the max batch size assuming that the object doesn't have a max batch + // size yet. + ClusterBatchSize& SetMaxBatchSize(int max_batch_size); + absl::optional GetOptionalMaxBatchSize() const; + + // Merge `other` into the current ClusterBatchSize if the two are not + // conflicting. Two ClusterBatchSizes are conflicting iff they both have a + // value and their values are different. bool MergeIfCompatible(const ClusterBatchSize& other); - // Returns a string for the batch size. - // If the object has a static batch size, return a string for the value. - // If the object has a dynamic size, return -1. - // Otherwise, returns -2 to represent that the batch size hasn't been set - // yet. + // Returns a string for the batch size and the annotated max batch size. + // For the batch size: + // If the object has a static batch size, return a string representing a + // non-negative integer. + // If the object has a dynamic batch size, return a string representing a + // negative integer as an equivalent class. + // If the object doesn't have a batch size yet, return "?". + // For the annotated max batch size: + // If the cluster has annotated max batch size in at least one of the nodes, + // return a string representing the annotated max batch size. Otherwise, + // return "?". std::string ToString() const; private: - bool HasDynamicBatchSize() const { return has_dynamic_batch_size_; } + ClusterBatchSize& SetBatchSize(const absl::optional& batch_size); + ClusterBatchSize& SetMaxBatchSize(const absl::optional& batch_size); - // To track whether the cluster has any node with dynamic batch size. - bool has_dynamic_batch_size_; - // To track whether the cluster has any node with static batch size, and the - // unique value for static batch size. - absl::optional static_batch_size_; + absl::optional batch_size_; + absl::optional max_batch_size_; }; inline std::ostream& operator<<(std::ostream& os, diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 1cce3424fae..588b4269fee 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -196,7 +196,7 @@ filegroup( srcs = [ "xla_compiled_cpu_function.h", "//tensorflow/compiler/xla:cpu_runtime_hdrs", - "//tensorflow/compiler/xla/service/cpu:single_threaded_runtime_hdrs", + "//tensorflow/compiler/xla/service/cpu:runtime_hdrs", "//tensorflow/core/kernels:xla_cpu_runtime_hdrs", "//tensorflow/core/platform:xla_cpu_runtime_srcs", ], @@ -208,7 +208,7 @@ filegroup( srcs = [ "xla_compiled_cpu_function.cc", "//tensorflow/compiler/xla:cpu_runtime_srcs", - "//tensorflow/compiler/xla/service/cpu:single_threaded_runtime_srcs", + "//tensorflow/compiler/xla/service/cpu:runtime_srcs", "//tensorflow/core/kernels:xla_cpu_runtime_srcs", "//tensorflow/core/platform:xla_cpu_runtime_srcs", ], @@ -249,6 +249,11 @@ cc_library( "//third_party/eigen3", "//tensorflow/core/framework:numeric_types", "//tensorflow/core/platform:bfloat16", + ] + [ + # Extra dependencies required for multithreaded runtime objects. + "//tensorflow/core/platform:blocking_counter", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:mutex", ] + tf_additional_tensor_coding_deps(), alwayslink = 1, ) @@ -742,10 +747,10 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/framework:tensor_testutil", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index a46cceddced..cfd05f18c8d 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -181,7 +181,7 @@ Status CompileImpl( } xla::Literal alg_literal; TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal)); - Algorithm alg = Algorithm(alg_literal.Get({})); + Algorithm alg = Algorithm(alg_literal.Get({})); if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) { return errors::InvalidArgument("Unsupported algorithm id: ", alg); } diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index 09d3ef1f6d4..2f08a80e975 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -31,7 +31,12 @@ class MlirBridgePass : public MlirOptimizationPass { bool IsEnabled(const ConfigProto& config_proto) const override { return config_proto.experimental().enable_mlir_bridge() || - tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; + config_proto.experimental().mlir_bridge_rollout() == + tensorflow::ConfigProto::Experimental:: + MLIR_BRIDGE_ROLLOUT_ENABLED || + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == + tensorflow::ConfigProto::Experimental:: + MLIR_BRIDGE_ROLLOUT_ENABLED; } // This should be used as a thin mapper around mlir::ModulePass::runOnModule @@ -48,7 +53,12 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { bool IsEnabled(const ConfigProto& config_proto) const override { return config_proto.experimental().enable_mlir_bridge() || - tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; + config_proto.experimental().mlir_bridge_rollout() == + tensorflow::ConfigProto::Experimental:: + MLIR_BRIDGE_ROLLOUT_ENABLED || + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == + tensorflow::ConfigProto::Experimental:: + MLIR_BRIDGE_ROLLOUT_ENABLED; } // This should be used as a thin mapper around mlir::ModulePass::runOnModule diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 00cbe7bc9bc..471cc029a59 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -792,7 +792,7 @@ REGISTER_OP("XlaGather") .Input("slice_sizes: Tindices") .Attr("dimension_numbers: string") .Attr("indices_are_sorted: bool") - .Attr("T: numbertype") + .Attr("T: {numbertype, bool}") .Attr("Tindices: {int32, int64}") .Output("output: T") .SetShapeFn(shape_inference::UnknownShape) @@ -813,7 +813,7 @@ REGISTER_OP("XlaScatter") .Attr("update_computation: func") .Attr("dimension_numbers: string") .Attr("indices_are_sorted: bool") - .Attr("T: numbertype") + .Attr("T: {numbertype, bool}") .Attr("Tindices: {int32, int64}") .Output("output: T") .SetShapeFn(shape_inference::UnchangedShape) diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index fdd8484c249..5c8cfdde9e4 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -627,8 +627,28 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { graph_optimizer_options.inline_with_single_device_body_placer = true; graph_optimizer_options.ignore_noinline = is_inside_mustcompile; - optimizer.Optimize(flib_runtime_, flib_runtime_->env(), - /*device=*/nullptr, &graph, graph_optimizer_options); + { + GraphShapeInfo shape_info; + InferShapes(graph.get(), /*arg_shapes=*/{}, + flib_runtime_->GetFunctionLibraryDefinition(), &shape_info) + .IgnoreError(); + auto node_name_index = graph->BuildNodeNameIndex(); + std::unordered_map> shape_map; + for (const auto& node_shape_info : shape_info) { + const string& node_name = node_shape_info.first; + const std::vector& output_shapes = node_shape_info.second; + const auto& node_iter = node_name_index.find(node_name); + if (node_iter != node_name_index.end()) { + auto& partial_shapes = shape_map[node_name]; + for (const auto& inferred_shape : output_shapes) { + partial_shapes.push_back(inferred_shape.shape); + } + } + } + graph_optimizer_options.shape_map = &shape_map; + optimizer.Optimize(flib_runtime_, flib_runtime_->env(), + /*device=*/nullptr, &graph, graph_optimizer_options); + } // Run shape inference on the graph and optimize the graph again. GraphShapeInfo shape_info; @@ -734,13 +754,15 @@ Status XlaCompiler::CompileFunction( VLOG(1) << "===================================================="; #ifdef LIBTPU_ON_GCE - if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { VLOG(1) << "MLIR is not supported in this environment."; } TF_RETURN_IF_ERROR( CompileGraph(options, function_id, std::move(graph), args, result)); #else - if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; std::vector control_rets; diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 410c86732d6..76cc6f0159b 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -203,7 +203,7 @@ static XlaOp ErfcImpl32(XlaOp x) { // Precondition: abs(x) <= 1. Otherwise, use ErfcImpl. // // This follows Cephes's f32 implementation of erf. -static XlaOp ErfImpl32(XlaOp x) { +static XlaOp ErfImpl32Cephes(XlaOp x) { // Coefficients for by erf(f32), from Cephes. // // erf(x) = x P(x^2), 0 < x < 1 @@ -291,11 +291,31 @@ XlaOp Erfc(XlaOp x) { // (not surprising!), so upcast to f32 in this case. return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x), - ScalarLike(x, 1) - ErfImpl32(x)); + ScalarLike(x, 1) - ErfImpl32Cephes(x)); }); }); } +// Compute a polynomial approximation of the error function. +// This is the same approximation used by Eigen. +static XlaOp ErfImpl32(XlaOp x) { + static const std::array kAlpha{ + -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, + -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, + -1.60960333262415e-02f, + }; + + static const std::array kBeta{ + -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f, + -7.37332916720468e-03f, -1.42647390514189e-02f, + }; + + x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f)); + auto x2 = x * x; + return x * EvaluatePolynomial(x2, kAlpha) / + EvaluatePolynomial(x2, kBeta); +} + XlaOp Erf(XlaOp x) { auto& b = *x.builder(); return b.ReportErrorOrReturn([&]() -> StatusOr { @@ -310,10 +330,8 @@ XlaOp Erf(XlaOp x) { } // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { - return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl32(x), - ScalarLike(x, 1) - ErfcImpl32(x)); - }); + return DoWithUpcastToF32(x, {BF16, F16}, + [](XlaOp x) { return ErfImpl32(x); }); }); } diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 41212e69b2e..b44673015bb 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -2324,31 +2324,53 @@ XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value, absl::Span window_dimensions, absl::Span window_strides, Padding padding) { - return ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_RETURN_IF_ERROR( - ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()), - window_dimensions, window_strides)); + return ReduceWindow(absl::MakeSpan(&operand, 1), + absl::MakeSpan(&init_value, 1), computation, + window_dimensions, window_strides, padding); +} +XlaOp XlaBuilder::ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding) { + return ReportErrorOrReturn([&]() -> StatusOr { + const Shape* operand_shape = nullptr; + for (const auto& operand : operands) { + TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand)); + TF_RETURN_IF_ERROR( + ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()), + window_dimensions, window_strides)); + } + CHECK(operand_shape != nullptr); std::vector> padding_values = MakePadding(AsInt64Slice(operand_shape->dimensions()), window_dimensions, window_strides, padding); return ReduceWindowWithGeneralPadding( - operand, init_value, computation, window_dimensions, window_strides, + operands, init_values, computation, window_dimensions, window_strides, /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values); }); } XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( - XlaOp operand, XlaOp init_value, const XlaComputation& computation, + absl::Span operands, absl::Span init_values, + const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, absl::Span window_dilations, absl::Span> padding) { + std::vector operand_shapes, init_shapes; return ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value)); + for (int i = 0; i < operands.size(); ++i) { + const auto& operand = operands[i]; + const auto& init_value = init_values[i]; + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + operand_shapes.push_back(operand_shape); + TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value)); + init_shapes.push_back(init_shape); + } TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN(auto window, @@ -2358,12 +2380,33 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( /*rhs_dilation=*/window_dilations)); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferReduceWindowShape( - *operand_shape, *init_shape, window, to_apply_shape)); - return ReduceWindowInternal(shape, operand, init_value, computation, + absl::MakeSpan(operand_shapes), + absl::MakeSpan(init_shapes), window, to_apply_shape)); + return ReduceWindowInternal(shape, operands, init_values, computation, std::move(window)); }); } +StatusOr XlaBuilder::ReduceWindowInternal( + const Shape& shape, absl::Span operands, + absl::Span init_values, const XlaComputation& computation, + Window window) { + if (operands.size() == 1) { + return ReduceWindowInternal(shape, operands[0], init_values[0], computation, + window); + } else { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + *instr.mutable_window() = std::move(window); + AddCalledComputation(computation, &instr); + std::vector args; + args.insert(args.end(), operands.begin(), operands.end()); + args.insert(args.end(), init_values.begin(), init_values.end()); + return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, + absl::MakeSpan(args)); + } +} + StatusOr XlaBuilder::ReduceWindowInternal( const Shape& shape, XlaOp operand, XlaOp init_value, const XlaComputation& computation, Window window) { @@ -4067,6 +4110,17 @@ XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value, padding); } +XlaOp ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { + CHECK(!operands.empty()); + return operands[0].builder()->ReduceWindow(operands, init_values, computation, + window_dimensions, window_strides, + padding); +} + XlaOp ReduceWindowWithGeneralPadding( const XlaOp operand, const XlaOp init_value, const XlaComputation& computation, @@ -4076,8 +4130,9 @@ XlaOp ReduceWindowWithGeneralPadding( absl::Span window_dilations, absl::Span> padding) { return operand.builder()->ReduceWindowWithGeneralPadding( - operand, init_value, computation, window_dimensions, window_strides, - base_dilations, window_dilations, padding); + absl::MakeSpan(&operand, 1), absl::MakeSpan(&init_value, 1), computation, + window_dimensions, window_strides, base_dilations, window_dilations, + padding); } XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index f736ae1d470..05efc038082 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -648,18 +648,28 @@ class XlaBuilder { absl::Span window_dimensions, absl::Span window_strides, Padding padding); + XlaOp ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + XlaOp ReduceWindowWithGeneralPadding( - XlaOp operand, XlaOp init_value, const XlaComputation& computation, + absl::Span operands, absl::Span init_values, + const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, absl::Span window_dilations, absl::Span> padding); - + StatusOr ReduceWindowInternal(const Shape& shape, + absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + Window window); virtual StatusOr ReduceWindowInternal( const Shape& shape, XlaOp operand, XlaOp init_value, const XlaComputation& computation, Window window); - XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups = {}); @@ -1137,6 +1147,12 @@ class XlaBuilder { absl::Span window_dimensions, absl::Span window_strides, Padding padding); + friend XlaOp ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding); friend XlaOp ReduceWindowWithGeneralPadding( XlaOp operand, XlaOp init_value, const XlaComputation& computation, absl::Span window_dimensions, @@ -1965,6 +1981,12 @@ XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, absl::Span window_dimensions, absl::Span window_strides, Padding padding); +XlaOp ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + // As ReduceWindow(), but the padding is given in the format // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index bfd13c8ddf5..4fc6c848a38 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -873,6 +873,8 @@ TEST_F(XlaBuilderTest, DynamicReduceWindow) { ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4}, /*window_strides=*/{1, 1, 1}, Padding::kValid); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + VLOG(2) << module->entry_computation()->root_instruction()->ToString() + << "\n"; const Shape& result_shape = module->entry_computation()->root_instruction()->shape(); EXPECT_TRUE( @@ -880,6 +882,46 @@ TEST_F(XlaBuilderTest, DynamicReduceWindow) { << result_shape; } +TEST_F(XlaBuilderTest, VariadicDynamicReduceWindow) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto p1 = Parameter(&b, 1, tuple_param_shape, "p1"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p1, 0); + std::vector input_operands = {gte0, gte1}; + XlaBuilder bsum(TestName()); + auto p2 = Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x0"); + auto p3 = Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "x1"); + auto p4 = Parameter(&bsum, 2, ShapeUtil::MakeShape(F32, {}), "y0"); + auto p5 = Parameter(&bsum, 3, ShapeUtil::MakeShape(F32, {}), "y1"); + std::vector output_operands = {Add(p2, p4), Add(p3, p5)}; + Tuple(&bsum, absl::MakeSpan(output_operands)); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + auto init = ConstantR0(&b, 0.f); + ReduceWindow(input_operands, {init, init}, sum, + /*window_dimensions=*/{1, 2, 4}, + /*window_strides=*/{1, 1, 1}, Padding::kValid); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + VLOG(2) << module->entry_computation()->root_instruction()->ToString() + << "\n"; + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.tuple_shapes(0).dynamic_dimensions(), + {true, false, false})) + << result_shape.tuple_shapes(0); + EXPECT_TRUE(ContainersEqual(result_shape.tuple_shapes(1).dynamic_dimensions(), + {true, false, false})) + << result_shape.tuple_shapes(1); +} + TEST_F(XlaBuilderTest, DynamicSelectAndScatter) { XlaBuilder b(TestName()); Shape tuple_param_shape = ShapeUtil::MakeTupleShape( diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 43a3860e405..1ff96db8637 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -176,6 +176,7 @@ cc_library( "//learning/brain/research/jax:__subpackages__", "//learning/deepmind/tensorflow/tensorfn:__subpackages__", "//learning/pathways:__subpackages__", + "//tensorflow/compiler/xla:friends", ], deps = [ ":local_device_state", diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc index c571ef2a4df..37c4ab3b7c5 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/service/platform_util.h" namespace xla { @@ -25,7 +26,7 @@ static const char kCpuPlatformName[] = "cpu"; CpuDevice::CpuDevice(int id, std::unique_ptr local_device_state) - : PjRtDevice(id, std::move(local_device_state), kCpuPlatformName, + : PjRtDevice(id, std::move(local_device_state), /*device_kind=*/kCpuPlatformName) {} StatusOr> GetCpuClient(bool asynchronous) { @@ -57,7 +58,7 @@ StatusOr> GetCpuClient(bool asynchronous) { } return std::make_unique( - kCpuPlatformName, client, std::move(devices), /*host_id=*/0, + PjRtPlatformId::kCpu, client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, /*gpu_run_options=*/nullptr); diff --git a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc index c56b41861b0..f43ec5a9216 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc @@ -54,9 +54,9 @@ TEST(GpuMultiStream, Basics) { device_assignment(0, 0) = device->id(); compile_options.executable_build_options.set_device_assignment( device_assignment); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, - PjRtExecutable::Compile(computation, client.get(), - std::move(compile_options))); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + client->Compile(computation, std::move(compile_options))); int64 dummy_size = 1 << 20; std::vector dummy_inputs(dummy_size); @@ -71,22 +71,22 @@ TEST(GpuMultiStream, Basics) { // must wait. TF_ASSERT_OK_AND_ASSIGN( auto dummy_buffer, - PjRtBuffer::FromHostBuffer( + client->BufferFromHostBuffer( dummy_inputs.data(), dummy_shape, - PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes, - /*buffer_reference=*/nullptr, client.get(), device)); + PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes, + /*buffer_reference=*/nullptr, device)); TF_ASSERT_OK_AND_ASSIGN( auto in_buffer0, - PjRtBuffer::FromHostBuffer( + client->BufferFromHostBuffer( inputs.data(), shape, - PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes, - /*buffer_reference=*/nullptr, client.get(), device)); + PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes, + /*buffer_reference=*/nullptr, device)); TF_ASSERT_OK_AND_ASSIGN( auto in_buffer1, - PjRtBuffer::FromHostBuffer( + client->BufferFromHostBuffer( inputs.data(), shape, - PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes, - /*buffer_reference=*/nullptr, client.get(), device)); + PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes, + /*buffer_reference=*/nullptr, device)); // The execution may be enqueued before the transfers complete, requiring // adequate device-side synchronization. ExecuteOptions options; diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc index 376d8687892..53a4bed8bb5 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/service/platform_util.h" namespace xla { @@ -25,7 +26,7 @@ static const char kInterpreterPlatformName[] = "interpreter"; InterpreterDevice::InterpreterDevice( int id, std::unique_ptr local_device_state) - : PjRtDevice(id, std::move(local_device_state), kInterpreterPlatformName, + : PjRtDevice(id, std::move(local_device_state), /*device_kind=*/kInterpreterPlatformName) {} StatusOr> GetInterpreterClient() { @@ -51,7 +52,7 @@ StatusOr> GetInterpreterClient() { devices.push_back(std::move(device)); return std::make_unique( - kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0, + PjRtPlatformId::kInterpreter, client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, /*gpu_run_options=*/nullptr); diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index df92921c39d..5003c8a7cde 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -30,8 +30,6 @@ limitations under the License. namespace xla { namespace { -static const char kGpuPlatformName[] = "gpu"; - // A custom PjRtClient that overrides the device assignment method. class GpuClient : public xla::PjRtClient { public: @@ -298,8 +296,8 @@ Status BuildDistributedDevices( GpuDevice::GpuDevice(int id, std::unique_ptr local_device_state, std::string device_kind, int node_id) - : PjRtDevice(id, std::move(local_device_state), kGpuPlatformName, - std::move(device_kind), node_id) {} + : PjRtDevice(id, std::move(local_device_state), std::move(device_kind), + node_id) {} StatusOr> GetNvidiaGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, @@ -325,7 +323,7 @@ StatusOr> GetNvidiaGpuClient( } return std::unique_ptr(std::make_unique( - "gpu", xla_client, std::move(devices), + PjRtPlatformId::kNvidiaGpu, xla_client, std::move(devices), /*node_id=*/node_id, std::move(allocator), std::move(host_memory_allocator), /*should_stage_host_to_device_transfers=*/true, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index 02ae37b71db..83ed61cfe63 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -113,6 +113,13 @@ limitations under the License. namespace xla { +PjRtPlatformId PjRtDevice::platform_id() const { + return client_->platform_id(); +} +const std::string& PjRtDevice::platform_name() const { + return client_->platform_name(); +} + StatusOr PjRtDevice::GetLocalDeviceState() const { if (local_device_state_) { return local_device_state_.get(); @@ -145,8 +152,8 @@ StatusOr DevicesToDeviceAssignment( devices[replica].size(), replica, devices[0].size()); } for (int partition = 0; partition < devices[replica].size(); ++partition) { - if (devices[0][0]->platform_name() != - devices[replica][partition]->platform_name()) { + if (devices[0][0]->platform_id() != + devices[replica][partition]->platform_id()) { return InvalidArgument( "Device assignment passed to Compile() must have devices of a " "single kind, got %s for replica 0 partition 0 and %s for replica " @@ -175,13 +182,14 @@ class CpuAllocator : public tensorflow::Allocator { }; PjRtClient::PjRtClient( - std::string platform_name, LocalClient* client, + PjRtPlatformId platform_id, LocalClient* client, std::vector> devices, int host_id, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, std::unique_ptr gpu_run_options) - : platform_name_(std::move(platform_name)), + : platform_id_(platform_id), + platform_name_(Name(platform_id)), client_(client), host_memory_allocator_(std::move(host_memory_allocator)), devices_(std::move(devices)), @@ -206,15 +214,15 @@ PjRtClient::PjRtClient( CHECK(id_to_device_.insert({device->id(), device.get()}).second) << "Duplicate device id: " << device->id(); - if (device->local_device_state()) { - int idx = device->local_device_state()->device_ordinal(); + if (device->IsLocalDevice()) { + int idx = device->local_device_id(); if (idx >= local_devices_.size()) { local_devices_.resize(idx + 1); } CHECK(local_devices_[idx] == nullptr) << idx; local_devices_[idx] = device.get(); } - device->client_ = this; + device->SetClient(this); } for (int idx = 0; idx < local_devices_.size(); ++idx) { CHECK(local_devices_[idx] != nullptr) << idx; @@ -227,62 +235,6 @@ StatusOr PjRtClient::GetDefaultDeviceAssignment( num_partitions); } -StatusOr> PjRtClient::GetParametersThatMustBeDonated( - const LocalExecutable& executable, bool tuple_inputs) const { - HloComputation* computation = - executable.executable()->module().entry_computation(); - int number_of_parameters = [&]() -> int { - if (tuple_inputs) { - CHECK_EQ(computation->num_parameters(), 1); - const Shape& input_tuple_shape = - computation->parameter_instruction(0)->shape(); - CHECK(input_tuple_shape.IsTuple()); - return input_tuple_shape.tuple_shapes_size(); - } else { - return computation->num_parameters(); - } - }(); - // If any buffer in a parameter is aliased we will donate the entire input - // parameter. - absl::flat_hash_set parameters_to_donate; - const HloInputOutputAliasConfig& config = - executable.executable()->module().input_output_alias_config(); - TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus( - [&](const ShapeIndex& output_index, - const HloInputOutputAliasConfig::Alias& alias) { - if (tuple_inputs) { - if (alias.parameter_number != 0) { - return InvalidArgument( - "Unexpected parameter number %d in alias config with tupled " - "inputs", - alias.parameter_number); - } - const ShapeIndex& index = alias.parameter_index; - if (!index.empty()) { - int this_parameter = index.data()[0]; - if (this_parameter >= number_of_parameters) { - return InvalidArgument( - "Unexpected parameter index %s in alias config with tupled " - "inputs and %d parameters", - index.ToString(), number_of_parameters); - } - parameters_to_donate.insert(this_parameter); - } - } else { - int this_parameter = alias.parameter_number; - if (this_parameter >= number_of_parameters) { - return InvalidArgument( - "Unexpected parameter number %d in alias config without tupled " - "inputs and %d parameters", - this_parameter, number_of_parameters); - } - parameters_to_donate.insert(this_parameter); - } - return Status::OK(); - })); - return parameters_to_donate; -} - std::unique_ptr PjRtClient::GetHloCostAnalysis() { return absl::make_unique( client_->backend().compiler()->ShapeSizeBytesFunction()); @@ -576,24 +528,25 @@ void PjRtBuffer::ScopedHold::AddToInput( } } -/* static */ -StatusOr> PjRtBuffer::FromHostBuffer( +bool PjRtBuffer::IsOnCpu() const { + return client()->platform_id() == PjRtPlatformId::kCpu; +} + +StatusOr> PjRtClient::BufferFromHostBuffer( const void* data, const Shape& shape, HostBufferSemantics host_buffer_semantics, - std::shared_ptr buffer_reference, PjRtClient* client, - PjRtDevice* device) { - tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer"); - VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString() + std::shared_ptr buffer_reference, PjRtDevice* device) { + tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer"); + VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString() << " device: " << device->DebugString(); if (shape.IsTuple()) { - return InvalidArgument("Use FromHostLiteral to transfer a tuple"); + return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple"); } TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, device->GetLocalDeviceState()); int64 size = ShapeUtil::ByteSizeOf(shape); - TransferManager* transfer_manager = - client->client()->backend().transfer_manager(); + TransferManager* transfer_manager = client()->backend().transfer_manager(); TF_ASSIGN_OR_RETURN(Shape compact_shape, transfer_manager->ChooseCompactLayoutForShape(shape)); @@ -628,10 +581,11 @@ StatusOr> PjRtBuffer::FromHostBuffer( }; buffer = se::DeviceMemoryBase(const_cast(data), size); } else { - void* staging_buffer = client->host_memory_allocator()->AllocateRaw( + void* staging_buffer = host_memory_allocator()->AllocateRaw( cpu_function_runtime::kMinAlign, size); - on_delete_callback = [staging_buffer, client]() { - client->host_memory_allocator()->DeallocateRaw(staging_buffer); + on_delete_callback = [staging_buffer, host_memory_allocator = + host_memory_allocator()]() { + host_memory_allocator->DeallocateRaw(staging_buffer); }; buffer = se::DeviceMemoryBase(staging_buffer, size); std::memcpy(staging_buffer, data, size); @@ -643,7 +597,7 @@ StatusOr> PjRtBuffer::FromHostBuffer( std::initializer_list{buffer}, definition_events, std::move(on_delete_callback)); return absl::make_unique( - shape, shape, std::move(device_buffer), client, device); + shape, shape, std::move(device_buffer), this, device); } } @@ -651,21 +605,22 @@ StatusOr> PjRtBuffer::FromHostBuffer( std::unique_ptr py_buffer, AllocateDestinationBuffer(compact_shape, device, local_device, local_device->host_to_device_stream(), - /*is_uninitialized_create=*/false, client)); + /*is_uninitialized_create=*/false, this)); - ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold()); + PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold()); CHECK(device_buffer.ok()); // If necessary, allocate a host-side buffer for staging host-to-device // transfers. On GPU this is a buffer in pinned memory. std::shared_ptr staging_buffer; if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall || - client->should_stage_host_to_device_transfers()) { - void* ptr = client->host_memory_allocator()->AllocateRaw( + should_stage_host_to_device_transfers()) { + void* ptr = host_memory_allocator()->AllocateRaw( tensorflow::Allocator::kAllocatorAlignment, size); - staging_buffer = std::shared_ptr(ptr, [client](void* ptr) { - client->host_memory_allocator()->DeallocateRaw(ptr); - }); + staging_buffer = std::shared_ptr( + ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) { + host_memory_allocator->DeallocateRaw(ptr); + }); } // Copy the buffer into a staging buffer before returning control to the @@ -684,14 +639,15 @@ StatusOr> PjRtBuffer::FromHostBuffer( // usage holds have gone away. // TODO(misard) assess if it would be preferable to introduce a heuristic to // put the transfer into the calling thread for small literals. - auto transfer_h2d = [client, transfer_manager, local_device, data, size, + auto transfer_h2d = [local_client = client(), transfer_manager, local_device, + data, size, movable_device_buffer{device_buffer.ToClosure()}, shape, py_buffer{py_buffer.get()}, compact_shape, on_device_shape{py_buffer->on_device_shape()}, staging_buffer{std::move(staging_buffer)}, buffer_reference{std::move(buffer_reference)}, host_buffer_semantics]() { - ScopedHold device_buffer(movable_device_buffer); + PjRtBuffer::ScopedHold device_buffer(movable_device_buffer); // This function uses TF_CHECK_OK and ValueOrDie() since we have no way // to report failures from a callback. However, the operations here are // unlikely to fail and not recoverable even if we were to fail: DMAs to @@ -699,7 +655,7 @@ StatusOr> PjRtBuffer::FromHostBuffer( // allocation. ShapedBuffer buffer = device_buffer->AsShapedBuffer( - compact_shape, on_device_shape, client->client()->platform()); + compact_shape, on_device_shape, local_client->platform()); // If applicable on the backend, stage the transfer via host memory // allocated via the host_memory_allocator. On GPU, this is pinned // memory. @@ -736,41 +692,38 @@ StatusOr> PjRtBuffer::FromHostBuffer( // already defers its work onto a stream (= thread on CPU). transfer_h2d(); } else { - client->h2d_transfer_pool()->Schedule(transfer_h2d); + h2d_transfer_pool()->Schedule(transfer_h2d); } return py_buffer; } -/* static */ -StatusOr> PjRtBuffer::CreateUninitialized( - const Shape& shape, PjRtClient* client, PjRtDevice* device) { - tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized"); - VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString() - << " device: " << device->DebugString(); +StatusOr> PjRtClient::CreateUninitializedBuffer( + const Shape& shape, PjRtDevice* device) { + tensorflow::profiler::TraceMe traceme( + "PjRtClient::CreateUninitializedBuffer"); + VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: " + << shape.ToString() << " device: " << device->DebugString(); TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, device->GetLocalDeviceState()); - TransferManager* transfer_manager = - client->client()->backend().transfer_manager(); + TransferManager* transfer_manager = client()->backend().transfer_manager(); TF_ASSIGN_OR_RETURN(Shape compact_shape, transfer_manager->ChooseCompactLayoutForShape(shape)); return AllocateDestinationBuffer(compact_shape, device, local_device, /*copy_stream=*/nullptr, - /*is_uninitialized_create=*/true, client); + /*is_uninitialized_create=*/true, this); } -/* static */ -StatusOr> PjRtBuffer::FromHostLiteral( - const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) { - tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral"); - VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: " +StatusOr> PjRtClient::BufferFromHostLiteral( + const LiteralSlice& literal, PjRtDevice* device) { + tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral"); + VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: " << literal.shape().ToString() << " device: " << device->DebugString(); TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, device->GetLocalDeviceState()); - TransferManager* transfer_manager = - client->client()->backend().transfer_manager(); + TransferManager* transfer_manager = client()->backend().transfer_manager(); TF_ASSIGN_OR_RETURN( Shape compact_shape, transfer_manager->ChooseCompactLayoutForShape(literal.shape())); @@ -778,9 +731,9 @@ StatusOr> PjRtBuffer::FromHostLiteral( std::unique_ptr py_buffer, AllocateDestinationBuffer(compact_shape, device, local_device, local_device->host_to_device_stream(), - /*is_uninitialized_create=*/false, client)); + /*is_uninitialized_create=*/false, this)); - ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold()); + PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold()); CHECK(device_buffer.ok()); // The host to device transfer is performed on a thread pool, mostly because @@ -789,11 +742,11 @@ StatusOr> PjRtBuffer::FromHostLiteral( // usage holds have gone away. // TODO(misard) assess if it would be preferable to introduce a heuristic to // put the transfer into the calling thread for small literals. - auto transfer_h2d = [client, transfer_manager, local_device, + auto transfer_h2d = [local_client = client(), transfer_manager, local_device, movable_device_buffer{device_buffer.ToClosure()}, literal, py_buffer{py_buffer.get()}, compact_shape, on_device_shape{py_buffer->on_device_shape()}]() { - ScopedHold device_buffer(movable_device_buffer); + PjRtBuffer::ScopedHold device_buffer(movable_device_buffer); // This function uses TF_CHECK_OK and ValueOrDie() since we have no way // to report failures from a callback. However, the operations here are // unlikely to fail and not recoverable even if we were to fail: DMAs to @@ -802,7 +755,7 @@ StatusOr> PjRtBuffer::FromHostLiteral( se::Stream* h2d_stream = local_device->host_to_device_stream(); ShapedBuffer buffer = device_buffer->AsShapedBuffer( - compact_shape, on_device_shape, client->client()->platform()); + compact_shape, on_device_shape, local_client->platform()); TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( h2d_stream, literal, buffer)); @@ -817,12 +770,12 @@ StatusOr> PjRtBuffer::FromHostLiteral( .IgnoreError(); // Can return error::Unimplemented QCHECK(h2d_stream->ok()); }; - client->h2d_transfer_pool()->Schedule(transfer_h2d); + h2d_transfer_pool()->Schedule(transfer_h2d); return py_buffer; } -/*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers( - absl::Span shapes, PjRtClient* client, PjRtDevice* device, +void PjRtClient::MakeCrossHostReceiveBuffers( + absl::Span shapes, PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier) { if (shapes.empty()) { notifier(InvalidArgument( @@ -843,7 +796,7 @@ StatusOr> PjRtBuffer::FromHostLiteral( StatusOr> buffer_or = AllocateDestinationBuffer(shape, device, local_device, /*copy_stream=*/nullptr, - /*is_uninitialized_create=*/false, client); + /*is_uninitialized_create=*/false, this); if (!buffer_or.ok()) { notifier(buffer_or.status()); return; @@ -851,7 +804,31 @@ StatusOr> PjRtBuffer::FromHostLiteral( buffers.push_back(buffer_or.ConsumeValueOrDie()); } - client->EnqueueCrossHostReceive(std::move(buffers), std::move(notifier)); + EnqueueCrossHostReceive(std::move(buffers), std::move(notifier)); +} + +// Transfer the given literal to the infeed queue of the given local device. +Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const { + // Only support infeed to local device. + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); + return local_device->client()->TransferToInfeedLocal( + literal, local_device->device_ordinal()); +} + +StatusOr PjRtDevice::TransferFromOutfeed(const Shape& shape) const { + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); + return local_device->client()->TransferFromOutfeedLocal( + shape, local_device->device_ordinal()); +} + +StatusOr PjRtClient::LookupLocalDevice(int local_device_id) const { + for (auto* device : local_devices_) { + if (local_device_id == device->local_device_id()) { + return device; + } + } + return InvalidArgument("No matching device found for local_device_id %d", + local_device_id); } PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape, @@ -1159,7 +1136,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy, StatusOr> PjRtBuffer::ToLiteral( const bool discard_cached_copy, absl::optional layout) { - tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral"); + tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral"); TF_ASSIGN_OR_RETURN(std::shared_ptr host_value, CopyToHostAsyncInternal(discard_cached_copy, layout)); if (host_value == nullptr) { @@ -1267,9 +1244,9 @@ StatusOr> PjRtBuffer::CopyToDevice( // Copying across PjRtClients involves a copy through the host. if (dst_device->client() != client_) { TF_ASSIGN_OR_RETURN(std::shared_ptr literal, ToLiteral()); - return FromHostBuffer(literal->untyped_data(), literal->shape(), - HostBufferSemantics::kZeroCopy, nullptr, - dst_device->client(), dst_device); + return dst_device->client()->BufferFromHostBuffer( + literal->untyped_data(), literal->shape(), + PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device); } TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device, @@ -1498,12 +1475,66 @@ PjRtExecutable::PjRtExecutable( } } -Status PjRtExecutable::SetUpDonation(PjRtClient* client, bool tuple_inputs) { +StatusOr> GetParametersThatMustBeDonated( + const HloModule& module, bool tuple_inputs) { + HloComputation* computation = module.entry_computation(); + int number_of_parameters = [&]() -> int { + if (tuple_inputs) { + CHECK_EQ(computation->num_parameters(), 1); + const Shape& input_tuple_shape = + computation->parameter_instruction(0)->shape(); + CHECK(input_tuple_shape.IsTuple()); + return input_tuple_shape.tuple_shapes_size(); + } else { + return computation->num_parameters(); + } + }(); + // If any buffer in a parameter is aliased we will donate the entire input + // parameter. + absl::flat_hash_set parameters_to_donate; + const HloInputOutputAliasConfig& config = module.input_output_alias_config(); + TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus( + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + if (tuple_inputs) { + if (alias.parameter_number != 0) { + return InvalidArgument( + "Unexpected parameter number %d in alias config with tupled " + "inputs", + alias.parameter_number); + } + const ShapeIndex& index = alias.parameter_index; + if (!index.empty()) { + int this_parameter = index.data()[0]; + if (this_parameter >= number_of_parameters) { + return InvalidArgument( + "Unexpected parameter index %s in alias config with tupled " + "inputs and %d parameters", + index.ToString(), number_of_parameters); + } + parameters_to_donate.insert(this_parameter); + } + } else { + int this_parameter = alias.parameter_number; + if (this_parameter >= number_of_parameters) { + return InvalidArgument( + "Unexpected parameter number %d in alias config without tupled " + "inputs and %d parameters", + this_parameter, number_of_parameters); + } + parameters_to_donate.insert(this_parameter); + } + return Status::OK(); + })); + return parameters_to_donate; +} + +Status PjRtExecutable::SetUpDonation(bool tuple_inputs) { parameters_that_must_be_donated_.reserve(executables_.size()); for (auto& executable : executables_) { - TF_ASSIGN_OR_RETURN( - absl::flat_hash_set parameters_to_donate, - client->GetParametersThatMustBeDonated(*executable, tuple_inputs)); + TF_ASSIGN_OR_RETURN(absl::flat_hash_set parameters_to_donate, + GetParametersThatMustBeDonated( + executable->executable()->module(), tuple_inputs)); parameters_that_must_be_donated_.emplace_back( std::move(parameters_to_donate)); } @@ -1974,6 +2005,19 @@ PjRtExecutable::ExecuteOnLocalDevices( return wrapped_results; } +StatusOr>> +PjRtExecutable::GetHloModules() { + std::vector> modules; + modules.reserve(executables().size()); + for (const auto& local_exec : executables()) { + if (!local_exec->executable()->has_module()) { + return InvalidArgument("Executable does not have HLO modules."); + } + modules.push_back(local_exec->executable()->shared_module()); + } + return std::move(modules); +} + namespace { StatusOr GetShardedShape(const Shape& shape, @@ -2061,14 +2105,13 @@ StatusOr, Shape>> GetShardedProgramShapes( } // namespace -/*static*/ StatusOr> PjRtExecutable::Compile( - const XlaComputation& computation, PjRtClient* client, - CompileOptions options) { - tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile"); +StatusOr> PjRtClient::Compile( + const XlaComputation& computation, CompileOptions options) { + tensorflow::profiler::TraceMe traceme("PjRtClient::Compile"); ExecutableBuildOptions& build_options = options.executable_build_options; if (!build_options.device_allocator()) { - build_options.set_device_allocator(client->allocator()); + build_options.set_device_allocator(allocator()); } int num_replicas; @@ -2084,14 +2127,14 @@ StatusOr, Shape>> GetShardedProgramShapes( num_partitions = 1; } else { if (!build_options.has_device_assignment()) { - VLOG(2) << "PjRtExecutable::Compile using default device_assignment."; + VLOG(2) << "PjRtClient::Compile using default device_assignment."; TF_ASSIGN_OR_RETURN( DeviceAssignment device_assignment, - client->GetDefaultDeviceAssignment(build_options.num_replicas(), - build_options.num_partitions())); + GetDefaultDeviceAssignment(build_options.num_replicas(), + build_options.num_partitions())); build_options.set_device_assignment(device_assignment); } - VLOG(2) << "PjRtExecutable::Compile device_assignment:\n" + VLOG(2) << "PjRtClient::Compile device_assignment:\n" << build_options.device_assignment().ToString(); num_replicas = build_options.device_assignment().replica_count(); num_partitions = build_options.device_assignment().computation_count(); @@ -2118,7 +2161,8 @@ StatusOr, Shape>> GetShardedProgramShapes( // Assign a default layout based on `sharded_shape` to any array subshapes in // `dst_shape` that are missing layouts. - auto assign_layouts = [client](const Shape& sharded_shape, Shape* dst_shape) { + auto assign_layouts = [local_client = client()](const Shape& sharded_shape, + Shape* dst_shape) { return ShapeUtil::ForEachMutableSubshapeWithStatus( dst_shape, [&](Shape* subshape, const ShapeIndex& idx) { if (subshape->IsArray() && !subshape->has_layout()) { @@ -2126,8 +2170,7 @@ StatusOr, Shape>> GetShardedProgramShapes( const Shape& sharded_subshape = ShapeUtil::GetSubshape(sharded_shape, idx); LayoutUtil::SetToDefaultLayout(subshape); - TF_ASSIGN_OR_RETURN(Shape layout, client->client() - ->backend() + TF_ASSIGN_OR_RETURN(Shape layout, local_client->backend() .transfer_manager() ->ChooseCompactLayoutForShape( sharded_subshape)); @@ -2162,8 +2205,8 @@ StatusOr, Shape>> GetShardedProgramShapes( for (int replica = 0; replica < num_replicas; ++replica) { for (int partition = 0; partition < num_partitions; ++partition) { int device_id = (*device_assignment)(replica, partition); - PjRtDevice* device = LookupDevice(*client, device_id); - if (device->host_id() != client->host_id()) { + PjRtDevice* device = LookupDevice(*this, device_id); + if (device->host_id() != host_id()) { VLOG(3) << "Non-local device: " << device_id; continue; } @@ -2185,15 +2228,14 @@ StatusOr, Shape>> GetShardedProgramShapes( TF_ASSIGN_OR_RETURN( std::vector> local_executables, - client->client()->Compile(computation, argument_layout_pointers, - build_options)); + client()->Compile(computation, argument_layout_pointers, build_options)); auto executable = absl::make_unique( std::move(local_executables), options.parameter_is_tupled_arguments, std::move(device_assignment), std::move(local_logical_device_ids), - std::move(local_devices), client); + std::move(local_devices), this); TF_RETURN_IF_ERROR( - executable->SetUpDonation(client, options.parameter_is_tupled_arguments)); + executable->SetUpDonation(options.parameter_is_tupled_arguments)); return executable; } diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index cb4ef9da85b..86805182525 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -36,11 +36,13 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -50,26 +52,63 @@ limitations under the License. namespace xla { +// TODO(zhangqiaorjc): Add a registration mechanism to add new platforms. +enum class PjRtPlatformId : int { + kCpu = 0, + kNvidiaGpu = 1, + kAmdGpu = 2, + kTpu = 3, + kEdgeTpu = 4, + kInterpreter = 5 +}; +constexpr const char* Name(PjRtPlatformId platform_id) { + switch (platform_id) { + case PjRtPlatformId::kCpu: + return "cpu"; + case PjRtPlatformId::kNvidiaGpu: + // TODO(zhangqiaorjc): Rename to nvidia_gpu when we add AMD support. + return "gpu"; + case PjRtPlatformId::kAmdGpu: + return "amd_gpu"; + case PjRtPlatformId::kTpu: + return "tpu"; + case PjRtPlatformId::kEdgeTpu: + return "edge_tpu"; + case PjRtPlatformId::kInterpreter: + return "interpreter"; + } +} + class PjRtClient; class PjRtDevice { public: explicit PjRtDevice(int id, std::unique_ptr local_device_state, - std::string platform_name, std::string device_kind, - int host_id = 0) + std::string device_kind, int host_id = 0) : id_(id), + local_device_id_( + local_device_state ? local_device_state->device_ordinal() : -1), local_device_state_(std::move(local_device_state)), host_id_(host_id), - platform_name_(std::move(platform_name)), device_kind_(std::move(device_kind)) {} virtual ~PjRtDevice() {} + // Must set client exactly once. + void SetClient(PjRtClient* client) { + CHECK(client_ == nullptr); + client_ = client; + } + // The ID of this device. IDs are unique among devices of this type // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all // hosts' devices. This is the ID that should be used in a DeviceAssignment. int id() const { return id_; } + bool IsLocalDevice() const { return local_device_id_ != -1; } + + int local_device_id() const { return local_device_id_; } + // If this is a device local to this host, returns a LocalDeviceState object // that can be used to manipulate the device. Returns nullptr if the device is // not local to this host. @@ -85,7 +124,11 @@ class PjRtDevice { // The ID of this device's host. This is always 0 on single-host platforms. int host_id() const { return host_id_; } - const std::string& platform_name() const { return platform_name_; } + // Return `platform_id` from client. + PjRtPlatformId platform_id() const; + + // Return `platform_name` from client. + const std::string& platform_name() const; // A vendor-dependent string that uniquely identifies the kind of device. const std::string& device_kind() const { return device_kind_; } @@ -94,13 +137,18 @@ class PjRtDevice { PjRtClient* client() const { return client_; } - private: - friend class PjRtClient; + // Transfer the given literal to the infeed queue of the given localdevice. + virtual Status TransferToInfeed(const LiteralSlice& literal) const; + // Transfer and return a value of the given shape from the outfeed of the + // given device. + virtual StatusOr TransferFromOutfeed(const Shape& shape) const; + + private: const int id_; + const int local_device_id_; // -1 means not local. const std::unique_ptr local_device_state_; const int host_id_; - const std::string platform_name_; const std::string device_kind_; PjRtClient* client_ = nullptr; }; @@ -120,6 +168,24 @@ struct PjRtCrossHostRecvBuffer { using PjRtCrossHostRecvNotifier = std::function>&&)>; +struct CompileOptions { + // The layouts of the arguments that the computation should expect. + absl::optional> argument_layouts; + + // If true, the supplied computation expects its arguments to be wrapped in a + // tuple and passed as a single parameter. + bool parameter_is_tupled_arguments = false; + + // XLA's compilation time options. + ExecutableBuildOptions executable_build_options; + + // If true, the executable can be run on any device. May only be true if + // !executable_build_options.has_device_assignment(), so only applies to + // single-device executables. Beware: on GPUs, sometimes an executable + // compiled for one device doesn't run on another. + bool compile_portable_executable = false; +}; + class PjRtExecutable; // Encapsulates the state of Python session with XLA. @@ -130,7 +196,7 @@ class PjRtClient { public: // `allocator` may null, in which case the platform default allocator is used. explicit PjRtClient( - std::string platform_name, LocalClient* client, + PjRtPlatformId platform_id, LocalClient* client, std::vector> devices, int host_id, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, @@ -153,12 +219,16 @@ class PjRtClient { return id_to_device_; } int host_id() const { return host_id_; } + PjRtPlatformId platform_id() const { return platform_id_; } const std::string& platform_name() const { return platform_name_; } LocalDeviceState& device_state(int device_ordinal) const { return *local_devices_.at(device_ordinal)->local_device_state(); } + // Return a local PjRtDevice for a given `local_device_id`. + virtual StatusOr LookupLocalDevice(int local_device_id) const; + LocalClient* client() const { return client_; } se::DeviceMemoryAllocator* allocator() const { return allocator_; } tensorflow::Allocator* host_memory_allocator() const { @@ -181,13 +251,6 @@ class PjRtClient { // function specifies which one the platform expects. virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; } - // Some platforms allow executables to donate buffers so that they can be - // aliased from inputs to outputs. This function returns the list of - // parameters that must be donated when executable is run. tuple_inputs - // reflects the option that executable was compiled with. - virtual StatusOr> GetParametersThatMustBeDonated( - const LocalExecutable& executable, bool tuple_inputs) const; - // Generates a unique fingerprint for `executable`. See // PjRtExecutable::fingerprint_. virtual StatusOr> ExecutableFingerprint( @@ -198,6 +261,73 @@ class PjRtClient { // Returns a backend-specific HLO cost analysis visitor. virtual std::unique_ptr GetHloCostAnalysis(); + virtual StatusOr> Compile( + const XlaComputation& computation, CompileOptions options); + + virtual StatusOr> CreateUninitializedBuffer( + const Shape& shape, PjRtDevice* device); + + // Describes the semantics the caller to BufferFromHostBuffer expects from the + // runtime, in a total order from most restrictive to least restrictive. + enum class HostBufferSemantics { + // The runtime may not hold references to `data` after the call to + // `BufferFromHostBuffer` completes. The caller promises that `data` is + // immutable and will not be freed only for the duration of the + // BufferFromHostBuffer call. `buffer_reference` will be freed by the time + // `BufferFromHostBuffer` returns. + kImmutableOnlyDuringCall, + + // The runtime may hold onto `data` after the call to `BufferFromHostBuffer` + // returns while the runtime completes a transfer to the device. The caller + // promises not to mutate or free `data` until the transfer completes, at + // which point the runtime will release `buffer_reference`. It is also + // correct to wait on the host (directly or indirectly) for the buffer's + // definition event to complete. + kImmutableUntilTransferCompletes, + + // The PjRtBuffer may alias `data` internally and the runtime may use the + // `data` contents as long as the buffer is alive. The caller promises to + // keep `data` alive and not to mutate its contents as long as the buffer is + // alive; to notify the caller that the buffer may be freed, the runtime + // will release its `buffer_reference` when the PjRtBuffer is freed. On + // non-CPU platforms this acts identically to + // kImmutableUntilTransferCompletes. + kZeroCopy, + }; + virtual StatusOr> BufferFromHostBuffer( + const void* data, const Shape& shape, + HostBufferSemantics host_buffer_semantics, + std::shared_ptr buffer_reference, PjRtDevice* device); + + // Note that literal must remain in scope until the transfer has completed, so + // the caller should, for example, wait for BlockHostUntilReady() completes on + // the return value before letting literal go out of scope. + virtual StatusOr> BufferFromHostLiteral( + const LiteralSlice& literal, PjRtDevice* device); + + // Asynchronously makes a vector of PjRtBuffers that can be used to receive + // cross host transfers using `client` on `device'. `shapes` must be the exact + // shapes, with identical layouts, corresponding to the buffers that will be + // sent. When resources for the transfer are available, notifier will be + // called with a vector of PjRtCrossHostRecvBuffer structs, one for each + // shape in `shapes`. Each struct contains a buffer that will contain the + // received value, and an opaque string that should be transmitted to the + // sending host and used in a call to CopyToRemoteDevice. None of the recv + // buffers will become ready until *all* of the sends have completed. + virtual void MakeCrossHostReceiveBuffers( + absl::Span shapes, PjRtDevice* device, + PjRtCrossHostRecvNotifier&& notifier); + + virtual StatusOr CreateChannelHandle() { + return client()->CreateChannelHandle(); + } + virtual StatusOr CreateDeviceToHostChannelHandle() { + return client()->CreateDeviceToHostChannelHandle(); + } + virtual StatusOr CreateHostToDeviceChannelHandle() { + return client()->CreateHostToDeviceChannelHandle(); + } + protected: friend class PjRtBuffer; virtual void EnqueueCrossHostReceive( @@ -211,7 +341,8 @@ class PjRtClient { return Unimplemented("Cross host sends not implemented."); } - std::string platform_name_; + const PjRtPlatformId platform_id_; + const std::string platform_name_; LocalClient* client_; // Allocator to be used for staging memory transfers to devices. @@ -385,6 +516,7 @@ class PjRtBuffer { private: friend class PjRtBuffer; + friend class PjRtClient; // Helper struct that makes it possible to move a ScopedHold through a // closure. @@ -423,66 +555,10 @@ class PjRtBuffer { StatusOr> buffer_or_; }; - // Returns a buffer with uninitialized contents. - static StatusOr> CreateUninitialized( - const Shape& shape, PjRtClient* client, PjRtDevice* device); - - // Describes the semantics the caller to FromHostBuffer expects from the - // runtime, in a total order from most restrictive to least restrictive. - enum class HostBufferSemantics { - // The runtime may not hold references to `data` after the call to - // `FromHostBuffer` completes. The caller promises that `data` is immutable - // and will not be freed only for the duration of the FromHostBuffer call. - // `buffer_reference` will be freed by the time `FromHostBuffer` returns. - kImmutableOnlyDuringCall, - - // The runtime may hold onto `data` after the call to `FromHostBuffer` - // returns while the runtime completes a transfer to the device. The caller - // promises not to mutate or free `data` until the transfer completes, at - // which point the runtime will release `buffer_reference`. It is also - // correct to wait on the host (directly or indirectly) for the buffer's - // definition event to complete. - kImmutableUntilTransferCompletes, - - // The PjRtBuffer may alias `data` internally and the runtime may use the - // `data` contents as long as the buffer is alive. - // The caller promises to keep `data` alive and not to mutate its contents - // as long as the buffer is alive; to notify the caller that the buffer may - // be freed, the runtime will release its `buffer_reference` when the - // PjRtBuffer is freed. On non-CPU platforms this acts identically to - // kImmutableUntilTransferCompletes. - kZeroCopy, - }; - static StatusOr> FromHostBuffer( - const void* data, const Shape& shape, - HostBufferSemantics host_buffer_semantics, - std::shared_ptr buffer_reference, PjRtClient* client, - PjRtDevice* device); - - // Note that literal must remain in scope until the transfer has completed, so - // the caller should, for example, wait for BlockHostUntilReady() completes on - // the return value before letting literal go out of scope. - static StatusOr> FromHostLiteral( - const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device); - - // Asynchronously makes a vector of PjRtBuffers that can be used to receive - // cross host transfers using `client` on `device'. `shapes` must be the exact - // shapes, with identical layouts, corresponding to the buffers that will be - // sent. When resources for the transfer are available, notifier will be - // called with a vector of PjRtCrossHostRecvBuffer structs, one for each - // shape in `shapes`. Each struct contains a buffer that will contain the - // received value, and an opaque string that should be transmitted to the - // sending host and used in a call to CopyToRemoteDevice. None of the recv - // buffers will become ready until *all* of the sends have completed. - static void MakeCrossHostReceiveBuffers(absl::Span shapes, - PjRtClient* client, - PjRtDevice* device, - PjRtCrossHostRecvNotifier&& notifier); - PjRtBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, PjRtClient* client, PjRtDevice* device); - ~PjRtBuffer(); + virtual ~PjRtBuffer(); PjRtBuffer(const PjRtBuffer&) = delete; PjRtBuffer(PjRtBuffer&&) = delete; @@ -492,6 +568,7 @@ class PjRtBuffer { const Shape& on_host_shape() const { return on_host_shape_; } const Shape& on_device_shape() const { return on_device_shape_; } PjRtDevice* device() const { return device_; } + PjRtPlatformId platform_id() const { return client_->platform_id(); } const std::string& platform_name() const { return client_->platform_name(); } PjRtClient* client() const { return client_; } bool IsEmptyTuple() const { @@ -584,6 +661,9 @@ class PjRtBuffer { // immediate use on the device. Useful in particular for timing benchmarks. Status BlockHostUntilReady(); + // Whether this buffer is on CPU and thus allows for certain optimizations. + bool IsOnCpu() const; + private: friend class PjRtClient; // The cached value of the buffer on the host, produced either from a call to @@ -661,24 +741,6 @@ class PjRtBuffer { Semaphore donation_semaphore_; }; -struct CompileOptions { - // The layouts of the arguments that the computation should expect. - absl::optional> argument_layouts; - - // If true, the supplied computation expects its arguments to be wrapped in a - // tuple and passed as a single parameter. - bool parameter_is_tupled_arguments = false; - - // XLA's compilation time options. - ExecutableBuildOptions executable_build_options; - - // If true, the executable can be run on any device. May only be true if - // !executable_build_options.has_device_assignment(), so only applies to - // single-device executables. Beware: on GPUs, sometimes an executable - // compiled for one device doesn't run on another. - bool compile_portable_executable = false; -}; - class ExecuteContext { public: virtual ~ExecuteContext() = default; @@ -710,10 +772,6 @@ struct ExecuteOptions { // buffer will be donated when passed to the execution. class PjRtExecutable { public: - static StatusOr> Compile( - const XlaComputation& computation, PjRtClient* client, - CompileOptions options); - PjRtExecutable(std::vector> executables, bool parameter_is_tupled_arguments, std::shared_ptr device_assignment, @@ -777,15 +835,19 @@ class PjRtExecutable { const string& name() const; + // Return an HloModule per partition. + StatusOr>> GetHloModules(); + protected: bool parameter_is_tupled_arguments() const { return parameter_is_tupled_arguments_; } private: + friend class PjRtClient; // Initializes information about which arguments to which executables must be // donated due to aliases that were specified by the computation. - Status SetUpDonation(PjRtClient* client, bool tuple_inputs); + Status SetUpDonation(bool tuple_inputs); virtual bool MustDonateParameter(int executable_idx, int parameter) const; @@ -844,6 +906,13 @@ class PjRtExecutable { std::vector local_devices_; }; +// Executables can donate buffers so that buffers can be aliased from inputs +// to outputs. This function returns the list of parameters that must be +// donated when executable is run. tuple_inputs reflects the option that +// executable was compiled with. +StatusOr> GetParametersThatMustBeDonated( + const HloModule& hlo_module, bool tuple_inputs); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc index a8711631605..5a28d82335e 100644 --- a/tensorflow/compiler/xla/pjrt/tpu_client.cc +++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc @@ -118,7 +118,7 @@ PjRtTpuClient::PjRtTpuClient(LocalClient* client, std::vector> devices, int host_id, tf_tpu::TpuPlatformInterface* tpu_platform) - : PjRtClient("tpu", client, std::move(devices), host_id, + : PjRtClient(PjRtPlatformId::kTpu, client, std::move(devices), host_id, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, @@ -145,7 +145,7 @@ StatusOr> PjRtTpuClient::ExecutableFingerprint( return InvalidArgument( "Passed executable from different client (platform '%s') to " "PjRtTpuClient::ExecutableFingerprint", - executable.client()->platform_name()); + Name(executable.client()->platform_id())); } if (executable.executables().size() > 1) { LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD " @@ -199,7 +199,8 @@ StatusOr>> GetTpuDevices( StatusOr> GetTpuClient( bool asynchronous, absl::Duration init_retry_timeout) { tf_tpu::TpuPlatformInterface* platform = - tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(); + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform( + /*initialize_platform=*/true, /*num_tries=*/1); if (platform == nullptr) { return InvalidArgument("TpuPlatform is not available."); } diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h index 1a458c1480b..cdc68bc9606 100644 --- a/tensorflow/compiler/xla/pjrt/tpu_client.h +++ b/tensorflow/compiler/xla/pjrt/tpu_client.h @@ -33,7 +33,7 @@ class PjRtTpuDevice : public PjRtDevice { int host_id, const std::array& coords, std::string device_kind) : PjRtDevice(core.Id(), std::move(local_device_state), - /*platform_name=*/"tpu", std::move(device_kind), host_id), + std::move(device_kind), host_id), core_(core), coords_(coords) {} diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index fb9c705fdf3..a3cba5dc44b 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -28,6 +28,18 @@ pyx_library( srcs = ["custom_call_for_test.pyx"], ) +py_test( + name = "xla_client_backend_independent_test", + srcs = ["xla_client_backend_independent_test.py"], + python_version = "PY3", + tags = ["no_oss"], # TODO(phawkins): This test passes, but requires --config=monolithic. + deps = [ + ":xla_client", + ":xla_extension", + "@absl_py//absl/testing:absltest", + ] + xla_py_test_deps(), +) + py_library( name = "xla_client_test", testonly = 1, @@ -227,6 +239,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":py_client", + ":python_ref_manager", ":traceback", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -434,6 +447,7 @@ pybind_extension( "//tensorflow/compiler/xla/pjrt:interpreter_device", "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device", "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:tpu_client", "//tensorflow/compiler/xla/pjrt:tracked_device_buffer", "//tensorflow/compiler/xla/pjrt/distributed", "//tensorflow/compiler/xla/pjrt/distributed:client", diff --git a/tensorflow/compiler/xla/python/bfloat16.cc b/tensorflow/compiler/xla/python/bfloat16.cc index b70244cc3ef..5dcfc3b0dcc 100644 --- a/tensorflow/compiler/xla/python/bfloat16.cc +++ b/tensorflow/compiler/xla/python/bfloat16.cc @@ -396,25 +396,30 @@ PyTypeObject PyBfloat16_Type = { PyArray_ArrFuncs NPyBfloat16_ArrFuncs; PyArray_Descr NPyBfloat16_Descr = { - PyObject_HEAD_INIT(nullptr) & PyBfloat16_Type, // typeobj + PyObject_HEAD_INIT(nullptr) // + /*typeobj=*/ + (&PyBfloat16_Type), // We must register bfloat16 with a kind other than "f", because numpy // considers two types with the same kind and size to be equal, but // float16 != bfloat16. // The downside of this is that NumPy scalar promotion does not work with // bfloat16 values. - 'V', // kind + /*kind=*/'V', // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type // character is unique. - 'E', // type - '=', // byteorder - NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM, // hasobject - 0, // type_num - sizeof(bfloat16), // elsize - alignof(bfloat16), // alignment - nullptr, // subarray - nullptr, // fields - nullptr, // names - &NPyBfloat16_ArrFuncs, // f + /*type=*/'E', + /*byteorder=*/'=', + /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM, + /*type_num=*/0, + /*elsize=*/sizeof(bfloat16), + /*alignment=*/alignof(bfloat16), + /*subarray=*/nullptr, + /*fields=*/nullptr, + /*names=*/nullptr, + /*f=*/&NPyBfloat16_ArrFuncs, + /*metadata=*/nullptr, + /*c_metadata=*/nullptr, + /*hash=*/-1, // -1 means "not computed yet". }; // Implementations of NumPy array methods. diff --git a/tensorflow/compiler/xla/python/bfloat16_test.py b/tensorflow/compiler/xla/python/bfloat16_test.py index 60b56bf810d..9aaa955d546 100644 --- a/tensorflow/compiler/xla/python/bfloat16_test.py +++ b/tensorflow/compiler/xla/python/bfloat16_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import collections +import copy import itertools import math @@ -254,6 +255,15 @@ class Bfloat16NumPyTest(parameterized.TestCase): def testDtype(self): self.assertEqual(bfloat16, np.dtype(bfloat16)) + def testDeepCopyDoesNotAlterHash(self): + # For context, see https://github.com/google/jax/issues/4651. If the hash + # value of the type descriptor is not initialized correctly, a deep copy + # can change the type hash. + dtype = np.dtype(bfloat16) + h = hash(dtype) + _ = copy.deepcopy(dtype) + self.assertEqual(h, hash(dtype)) + def testArray(self): x = np.array([[1, 2, 3]], dtype=bfloat16) self.assertEqual(bfloat16, x.dtype) diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 67afa25d23e..8f4045a0e7c 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -23,8 +23,10 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "include/dlpack/dlpack.h" // from @dlpack +#include "pybind11/pytypes.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/traceback.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -41,12 +43,30 @@ namespace { const char* const kDlTensorCapsuleName = "dltensor"; struct DLPackTensor { + ~DLPackTensor(); + + // At most one of buffer and buffer_reference/scoped_hold is populated. + + // `buffer` is populated if we have exclusive (read-write) access. std::shared_ptr buffer; + + // `buffer_reference` and `scoped_hold` are populated if we have + // shared (read-only) access. + py::object buffer_reference; + absl::optional scoped_hold; + std::vector shape; std::vector strides; DLManagedTensor tensor; }; +DLPackTensor::~DLPackTensor() { + if (buffer_reference) { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(&buffer_reference, /*size=*/1)); + } +} + void DLPackTensorDeleter(DLManagedTensor* t) { if (t) { delete static_cast(t->manager_ctx); @@ -208,68 +228,76 @@ StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { StatusOr DLContextForDevice(const PjRtDevice& device) { DLContext context; TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); - context.device_id = device.local_device_state()->device_ordinal(); + context.device_id = device.local_device_id(); return context; } StatusOr DeviceForDLContext(const PjRtClient& client, const DLContext& context) { - se::Platform::Id platform_id; switch (context.device_type) { case kDLCPU: - platform_id = se::host::kHostPlatformId; - break; + if (client.platform_id() != PjRtPlatformId::kCpu) { + return InvalidArgument( + "DLPack CPU device type mismatch with PjRtClient platform %s", + client.platform_name()); + } + return client.LookupLocalDevice(context.device_id); case kDLGPU: - platform_id = se::cuda::kCudaPlatformId; - break; + if (client.platform_id() != PjRtPlatformId::kNvidiaGpu) { + return InvalidArgument( + "DLPack GPU device type mismatch with PjRtClient platform %s", + client.platform_name()); + } + return client.LookupLocalDevice(context.device_id); default: return InvalidArgument("Unknown/unsupported DLPack device type %d", context.device_type); } - auto it = absl::c_find_if(client.local_devices(), [&](PjRtDevice* device) { - return device->local_device_state()->executor()->platform()->id() == - platform_id && - device->local_device_state()->device_ordinal() == context.device_id; - }); - if (it == client.local_devices().end()) { - return InvalidArgument( - "No matching device found for DLPack device_type %d device_id %d", - context.device_type, context.device_id); - } - return *it; } } // namespace -StatusOr BufferToDLPackManagedTensor(PyBuffer* buffer) { +StatusOr BufferToDLPackManagedTensor(py::handle py_buffer, + bool take_ownership) { + PyBuffer* buffer = py::cast(py_buffer); auto pack = std::make_unique(); - // Block on outstanding operations, so that it is safe to read or mutate the - // returned buffer. - StatusOr> buffer_or = - buffer->buffer()->Release(/*wait_for_operations_to_complete=*/true); - if (!buffer_or.ok()) { - return InvalidArgument( - "Buffer synchronization failed converting to DLPack tensor: %s", - buffer_or.status().ToString()); - } - pack->buffer = buffer_or.ConsumeValueOrDie(); - if (!pack->buffer) { - return InvalidArgument( - "Cannot convert deleted/invalid buffer to DLPack tensor."); - } - pack->tensor.manager_ctx = pack.get(); - pack->tensor.deleter = DLPackTensorDeleter; - DLTensor& dt = pack->tensor.dl_tensor; if (buffer->buffer()->on_device_shape().IsTuple()) { return Unimplemented( "unsafe_buffer_pointer is not implemented for tuple " "buffers."); } - TF_RET_CHECK(pack->buffer->device_memory().size() == 1); - dt.data = pack->buffer->device_memory().front().opaque(); + + DLTensor& dt = pack->tensor.dl_tensor; + if (take_ownership) { + // Block on outstanding operations, so that it is safe to read or mutate the + // returned buffer. + StatusOr> buffer_or = + buffer->buffer()->Release(/*wait_for_operations_to_complete=*/true); + if (!buffer_or.ok()) { + return InvalidArgument( + "Buffer synchronization failed converting to DLPack tensor: %s", + buffer_or.status().ToString()); + } + pack->buffer = buffer_or.ConsumeValueOrDie(); + if (!pack->buffer) { + return InvalidArgument( + "Cannot convert deleted/invalid buffer to DLPack tensor."); + } + TF_RET_CHECK(pack->buffer->device_memory().size() == 1); + dt.data = pack->buffer->device_memory().front().opaque(); + } else { + // Block on outstanding operations, so that it is safe to read or mutate the + // returned buffer. + TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady()); + pack->buffer_reference = py::reinterpret_borrow(py_buffer); + pack->scoped_hold.emplace( + buffer->buffer()->GetBufferWithExternalReference()); + dt.data = pack->scoped_hold->buffer()->device_memory().front().opaque(); + } + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device())); - dt.ctx.device_id = - buffer->buffer()->device()->local_device_state()->device_ordinal(); + dt.ctx.device_id = buffer->buffer()->device()->local_device_id(); dt.ndim = buffer->buffer()->on_host_shape().dimensions_size(); TF_ASSIGN_OR_RETURN(dt.dtype, PrimitiveTypeToDLDataType( diff --git a/tensorflow/compiler/xla/python/dlpack.h b/tensorflow/compiler/xla/python/dlpack.h index 7200997cf27..c39e6a7d932 100644 --- a/tensorflow/compiler/xla/python/dlpack.h +++ b/tensorflow/compiler/xla/python/dlpack.h @@ -22,7 +22,11 @@ limitations under the License. namespace xla { -StatusOr BufferToDLPackManagedTensor(PyBuffer* buffer); +// If take_ownership is true, ownership of the buffer is handed to DLPack, and +// the receiver may mutate the buffer as they see fit. Otherwise PjRt retains +// ownership of the buffer and it should be immutable. +StatusOr BufferToDLPackManagedTensor(pybind11::handle buffer, + bool take_ownership); StatusOr> DLPackManagedTensorToBuffer( const pybind11::capsule& tensor, std::shared_ptr client); diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index 2d392d41f37..b7d833e5948 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -130,11 +130,28 @@ struct CallSignature { PjRtDevice* device; bool operator==(const CallSignature& other) const { - return std::tie(dynamic_positional_args_treedef, static_args, keyword_args, + return std::tie(dynamic_positional_args_treedef, keyword_args, dynamic_args_signatures, device) == - std::tie(other.dynamic_positional_args_treedef, other.static_args, - other.keyword_args, other.dynamic_args_signatures, - other.device); + std::tie(other.dynamic_positional_args_treedef, + other.keyword_args, other.dynamic_args_signatures, + other.device) && + // `==` on py:objects is the Python `is`. We need equal. + std::equal( + static_args.begin(), static_args.end(), + other.static_args.begin(), other.static_args.end(), + [](const py::object& a, const py::object& b) { + try { + return a.equal(b); + } catch (const py::error_already_set& e) { + throw std::invalid_argument(absl::StrCat( + "static arguments should be comparable using __eq__." + "The following error was raised when comparing two " + "objects of types ", + py::cast(py::str(py::type::of(a))), " and ", + py::cast(py::str(py::type::of(b))), + ". The error was:\n", e.what())); + } + }); } bool operator!=(const CallSignature& other) const { return !(*this == other); @@ -169,12 +186,6 @@ H AbslHashValue(H h, const CallSignature::KwargEntry& kw) { template H AbslHashValue(H h, const CallSignature& s) { - // /!\ important: We cannot include static arguments to the hash, because - // the py::object must be hashable for absl. We can try delegating to the - // Python __hash__, but there are many non-hashable Python types such as - // np.ndarray. - // TODO(jblespiau): We should either ban non-hashable objects from jit or we - // should hash them by object identity. h = H::combine_contiguous(std::move(h), s.dynamic_positional_args_treedef.data(), s.dynamic_positional_args_treedef.size()); @@ -183,6 +194,20 @@ H AbslHashValue(H h, const CallSignature& s) { h = H::combine_contiguous(std::move(h), s.dynamic_args_signatures.data(), s.dynamic_args_signatures.size()); h = H::combine(std::move(h), s.device); + for (const auto& static_arg : s.static_args) { + ssize_t hash; + try { + hash = py::hash(static_arg); + } catch (const py::error_already_set& e) { + throw std::invalid_argument(absl::StrCat( + "Non-hashable static arguments are not supported. An error occured " + "while trying to hash an object of type ", + py::cast(py::str(py::type::of(static_arg))), ", ", + py::cast(py::str(static_arg)), ". The error was:\n", + e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + } return h; } @@ -269,6 +294,8 @@ class CompiledFunction { return inspect->attr("signature")(fun_); } + int cache_size() const { return executables_.size(); } + private: // Returns nullptr if not present in the cache. CacheEntry* GetCacheEntryIfPresent(const CallSignature& signature); @@ -457,10 +484,10 @@ std::unique_ptr ConvertToScalarBuffer( xla::PjRtDevice* device) { CppType data = py::cast(scalar); xla::Shape shape = xla::ShapeUtil::MakeShapeWithType({}); - return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( + return ValueOrThrow(client->BufferFromHostBuffer( &data, shape, - xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr, - client, device)); + xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr, + device)); } // Convert a scalar to the associated PjRtBuffer or raises an error if it is @@ -494,17 +521,17 @@ StatusOr> ScalarToBuffer( if (jax_enable_x64) { xla::complex128 data(result.real, result.imag); xla::Shape shape = xla::ShapeUtil::MakeShapeWithType({}); - return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( + return ValueOrThrow(client->BufferFromHostBuffer( &data, shape, - xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, - nullptr, client, device)); + xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, + nullptr, device)); } else { xla::complex64 data(result.real, result.imag); xla::Shape shape = xla::ShapeUtil::MakeShapeWithType({}); - return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( + return ValueOrThrow(client->BufferFromHostBuffer( &data, shape, - xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, - nullptr, client, device)); + xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, + nullptr, device)); } } return InvalidArgument( @@ -670,7 +697,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, ValueOrThrow(pyclient.BufferFromPyval( numpy_array, data_device, /*force_copy=*/false, /*host_buffer_semantics=*/ - xla::PjRtBuffer::HostBufferSemantics::kZeroCopy)); + xla::PjRtClient::HostBufferSemantics::kZeroCopy)); arg_buffers.push_back(buffer->buffer()); ArgSignature sig; @@ -789,6 +816,10 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { } default_pyclient_ = default_pydevice_.client; default_device_ = default_pydevice_.contents; + if (!default_device_) { // UPTC + always_fallback_to_python_ = true; + return py::cast(cache_miss_(*args, **kwargs))[0]; + } is_committed_ = py::cast(device_and_is_committed.attr("committed_to_device")); } @@ -866,6 +897,7 @@ void BuildJaxjitSubmodule(pybind11::module& m) { }); // Only for testing purposes + cfun.def("_cache_size", &CompiledFunction::cache_size); jitlib.def("_DtypeTo32BitDtype", [](const py::object obj) -> py::object { py::dtype dtype = py::dtype::from_args(obj); const py::dtype* res = DtypeTo32BitDtype(dtype); diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc index f6067e650c0..2535d62ee7e 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc @@ -409,10 +409,9 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) { compile_options.executable_build_options.set_device_assignment( device_assignment); - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - PjRtExecutable::Compile(computation, devices_[device_idx]->client(), - std::move(compile_options))); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + devices_[device_idx]->client()->Compile( + computation, std::move(compile_options))); ExecuteOptions execute_options; TF_ASSIGN_OR_RETURN(std::vector> output_buffers, executable->Execute({}, execute_options)); diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc index 919dafe2e0b..5422a4b3056 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc @@ -40,9 +40,8 @@ Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id, compile_options.executable_build_options.set_device_assignment( device_assignment); - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - PjRtExecutable::Compile(computation, client, std::move(compile_options))); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + client->Compile(computation, std::move(compile_options))); ExecuteOptions execute_options; TF_ASSIGN_OR_RETURN(std::vector> output_buffers, executable->Execute({}, execute_options)); diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index b32fe047530..cac14142b75 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ b/tensorflow/compiler/xla/python/py_buffer.cc @@ -144,7 +144,7 @@ int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { // Additionally we call BlockHostUntilReady() below, which may block. py::gil_scoped_release gil_release; - if (buffer.device()->platform_name() != "cpu") { + if (!buffer.IsOnCpu()) { return InvalidArgument( "Python buffer protocol is only defined for CPU buffers."); } diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc index 07b915c640c..d42bbdca154 100644 --- a/tensorflow/compiler/xla/python/py_client.cc +++ b/tensorflow/compiler/xla/python/py_client.cc @@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) { StatusOr> PyClient::BufferFromPyval( const pybind11::object& argument, PjRtDevice* device, bool force_copy, - PjRtBuffer::HostBufferSemantics host_buffer_semantics) { + PjRtClient::HostBufferSemantics host_buffer_semantics) { if (device == nullptr) { TF_RET_CHECK(!pjrt_client_->local_devices().empty()); device = pjrt_client_->local_devices().front(); @@ -114,10 +114,9 @@ StatusOr> PyClient::BufferFromPyval( std::unique_ptr buffer; { py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN( - buffer, PjRtBuffer::FromHostBuffer( - c->buf_ptr, c->shape, host_buffer_semantics, - std::move(py_buffer_ref), pjrt_client_.get(), device)); + TF_ASSIGN_OR_RETURN(buffer, pjrt_client_->BufferFromHostBuffer( + c->buf_ptr, c->shape, host_buffer_semantics, + std::move(py_buffer_ref), device)); } auto traceback = Traceback::Get(); return std::make_unique(shared_from_this(), std::move(buffer), @@ -131,8 +130,7 @@ StatusOr> PyClient::Compile( { py::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(executable, - PjRtExecutable::Compile(computation, pjrt_client_.get(), - std::move(options))); + pjrt_client_->Compile(computation, std::move(options))); TF_ASSIGN_OR_RETURN(fingerprint, pjrt_client_->ExecutableFingerprint(*executable)); } diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h index 08249722d6c..37f5333ea1c 100644 --- a/tensorflow/compiler/xla/python/py_client.h +++ b/tensorflow/compiler/xla/python/py_client.h @@ -112,18 +112,18 @@ class PyClient : public std::enable_shared_from_this { int num_replicas); StatusOr CreateChannelHandle() { - return pjrt_client_->client()->CreateChannelHandle(); + return pjrt_client_->CreateChannelHandle(); } StatusOr CreateDeviceToHostChannelHandle() { - return pjrt_client_->client()->CreateDeviceToHostChannelHandle(); + return pjrt_client_->CreateDeviceToHostChannelHandle(); } StatusOr CreateHostToDeviceChannelHandle() { - return pjrt_client_->client()->CreateHostToDeviceChannelHandle(); + return pjrt_client_->CreateHostToDeviceChannelHandle(); } StatusOr> BufferFromPyval( const pybind11::object& argument, PjRtDevice* device, bool force_copy, - PjRtBuffer::HostBufferSemantics host_buffer_semantics); + PjRtClient::HostBufferSemantics host_buffer_semantics); StatusOr> Compile( const XlaComputation& computation, CompileOptions options); diff --git a/tensorflow/compiler/xla/python/py_executable.cc b/tensorflow/compiler/xla/python/py_executable.cc index 53891b96846..9d1b89a1cbc 100644 --- a/tensorflow/compiler/xla/python/py_executable.cc +++ b/tensorflow/compiler/xla/python/py_executable.cc @@ -135,15 +135,7 @@ PyExecutable::ExecuteOnLocalDevices( StatusOr>> PyExecutable::HloModules() const { - std::vector> modules; - modules.reserve(executable_->executables().size()); - for (const auto& local_exec : executable_->executables()) { - if (!local_exec->executable()->has_module()) { - return InvalidArgument("Executable does not have HLO modules."); - } - modules.push_back(local_exec->executable()->shared_module()); - } - return std::move(modules); + return executable_->GetHloModules(); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index 0602d096aaa..6cd55d0e631 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -37,7 +37,7 @@ namespace xla { TpuDevice::TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip) - : xla::PjRtDevice(id, /*local_device_state=*/nullptr, kTpuPlatform, + : xla::PjRtDevice(id, /*local_device_state=*/nullptr, /*device_kind=*/"Cloud TPU", host_id), coords_(coords), core_on_chip_(core_on_chip) {} diff --git a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc index ee2e2fa57b1..a5a6cbabb82 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc @@ -736,24 +736,38 @@ class PodTpuDriver : public TpuDriver { auto done = [this, event_id]() { mu_.AssertHeld(); - if (events_.count(event_id) == 0) { - LOG(ERROR) << "Cannot find event id " << event_id - << " in WaitForEvent."; - } - return events_[event_id]->underlying_event != nullptr && - events_[event_id]->underlying_event.use_count() != 0; + // The event was either completed and erased from the map or we have + // an underlying event available to us. + return events_.count(event_id) == 0 || + (events_[event_id]->underlying_event != nullptr && + events_[event_id]->underlying_event.use_count() != 0); }; auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration); if (!status) { return absl::nullopt; } - underlying_event = events_[event_id]->underlying_event; + + if (events_.count(event_id) > 0) { + underlying_event = events_[event_id]->underlying_event; + } else { + underlying_event = nullptr; + } } // Wait for the underlying event without holding on to the event_lock_, or // else incoming events will not be processed. - return underlying_event->AwaitWithTimeout(duration); + if (underlying_event != nullptr) { + return underlying_event->AwaitWithTimeout(duration); + } else { + absl::MutexLock l(&mu_); + auto event_status = abnormal_event_status_.find(event_id); + if (event_status == abnormal_event_status_.end()) { + return Status::OK(); + } else { + return event_status->second; + } + } } void AddCallbackForEvent(int64_t event_id, std::function fn) @@ -768,13 +782,13 @@ class PodTpuDriver : public TpuDriver { } else { fn(event_status->second); } - } - - if (event->second->underlying_event != nullptr && - event->second->underlying_event.use_count() != 0) { - event->second->underlying_event->AddCallback(fn); } else { - event->second->callbacks.push_back(std::move(fn)); + if (event->second->underlying_event != nullptr && + event->second->underlying_event.use_count() != 0) { + event->second->underlying_event->AddCallback(fn); + } else { + event->second->callbacks.push_back(std::move(fn)); + } } } diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 0e92c85f6f6..fb8c0ba0ba4 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/interpreter_device.h" #include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/tpu_client.h" #include "tensorflow/compiler/xla/python/bfloat16.h" #include "tensorflow/compiler/xla/python/dlpack.h" #include "tensorflow/compiler/xla/python/jax_jit.h" @@ -205,6 +206,23 @@ bool IsOptimizedBuild() { #endif // NDEBUG } +// Safe version of ShapeUtil::MakeShapeWithLayout that fails gracefully on +// invalid input. +StatusOr MakeShapeWithLayout( + PrimitiveType element_type, absl::Span dims, + absl::optional> minor_to_major) { + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeUtil::MakeValidatedShape(element_type, dims)); + if (minor_to_major) { + *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major); + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(shape.layout(), shape)); + } else { + shape.clear_layout(); + } + return shape; +} + } // namespace PYBIND11_MODULE(xla_extension, m) { @@ -261,15 +279,13 @@ PYBIND11_MODULE(xla_extension, m) { .def_static( "array_shape", [](PrimitiveType type, py::object dims_seq, - absl::optional layout_seq) -> Shape { + absl::optional layout_seq) -> StatusOr { std::vector dims = IntSequenceToVector(dims_seq); if (layout_seq) { std::vector layout = IntSequenceToVector(*layout_seq); - return ShapeUtil::MakeShapeWithLayout(type, dims, layout); + return MakeShapeWithLayout(type, dims, layout); } else { - Shape shape = ShapeUtil::MakeShape(type, dims); - shape.clear_layout(); - return shape; + return MakeShapeWithLayout(type, dims, absl::nullopt); } }, "Constructs an array shape.", py::arg("type"), py::arg("dims"), @@ -277,16 +293,14 @@ PYBIND11_MODULE(xla_extension, m) { .def_static( "array_shape", [](py::dtype dtype, py::object dims_seq, - absl::optional layout_seq) -> Shape { + absl::optional layout_seq) -> StatusOr { PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); std::vector dims = IntSequenceToVector(dims_seq); if (layout_seq) { std::vector layout = IntSequenceToVector(*layout_seq); - return ShapeUtil::MakeShapeWithLayout(type, dims, layout); + return MakeShapeWithLayout(type, dims, layout); } else { - Shape shape = ShapeUtil::MakeShape(type, dims); - shape.clear_layout(); - return shape; + return MakeShapeWithLayout(type, dims, absl::nullopt); } }, "Constructs an array shape.", py::arg("type"), py::arg("dims"), @@ -465,10 +479,7 @@ PYBIND11_MODULE(xla_extension, m) { [](const PjRtDevice& device, const LiteralSlice& literal) { GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device.GetLocalDeviceState()); - return local_device->client()->TransferToInfeedLocal( - literal, local_device->device_ordinal()); + return device.TransferToInfeed(literal); }) .def("transfer_from_outfeed", [](const PjRtDevice& device, @@ -477,8 +488,6 @@ PYBIND11_MODULE(xla_extension, m) { std::shared_ptr literal_shared; { py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device.GetLocalDeviceState()); Shape shape_with_layout = shape; ShapeUtil::ForEachMutableSubshape( &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { @@ -486,10 +495,8 @@ PYBIND11_MODULE(xla_extension, m) { LayoutUtil::SetToDefaultLayout(subshape); } }); - TF_ASSIGN_OR_RETURN( - Literal literal, - local_device->client()->TransferFromOutfeedLocal( - shape_with_layout, local_device->device_ordinal())); + TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed( + shape_with_layout)); literal_shared = std::make_shared(std::move(literal)); } @@ -521,12 +528,12 @@ PYBIND11_MODULE(xla_extension, m) { .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform) .value("BFC", GpuAllocatorConfig::Kind::kBFC); - py::enum_(m, "HostBufferSemantics") + py::enum_(m, "HostBufferSemantics") .value("IMMUTABLE_ONLY_DURING_CALL", - PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall) + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", - PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes) - .value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy); + PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes) + .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy); py::class_> py_local_client(m, "Client"); py_local_client.def_property_readonly("platform", &PyClient::platform_name) @@ -548,7 +555,7 @@ PYBIND11_MODULE(xla_extension, m) { .def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"), py::arg("device") = nullptr, py::arg("force_copy") = false, py::arg("host_buffer_semantics") = - PjRtBuffer::HostBufferSemantics::kZeroCopy) + PjRtClient::HostBufferSemantics::kZeroCopy) .def("compile", &PyClient::Compile, py::arg("computation"), py::arg("compile_options") = CompileOptions()) .def("heap_profile", &PyClient::HeapProfile); @@ -580,6 +587,14 @@ PYBIND11_MODULE(xla_extension, m) { py::arg("asynchronous") = true, py::arg("allocator_config") = GpuAllocatorConfig(), py::arg("distributed_client") = nullptr, py::arg("node_id") = 0); + m.def( + "get_tpu_client", + [](bool asynchronous) -> StatusOr> { + TF_ASSIGN_OR_RETURN(std::shared_ptr client, + GetTpuClient(asynchronous)); + return std::make_shared(std::move(client)); + }, + py::arg("asynchronous") = true); py::class_(m, "Frame") .def_readonly("file_name", &Traceback::Frame::file_name) @@ -626,9 +641,7 @@ PYBIND11_MODULE(xla_extension, m) { [](py::object buffer_obj) -> StatusOr { GlobalPyRefManager()->CollectGarbage(); PyBuffer* buffer = buffer_obj.cast(); - LocalDeviceState* state = - buffer->buffer()->device()->local_device_state(); - if (state->executor()->platform_kind() == se::PlatformKind::kHost && + if (buffer->buffer()->IsOnCpu() && buffer->buffer()->on_device_shape().IsArray() && buffer->buffer()->on_device_shape().element_type() != BF16) { py::object out = py::reinterpret_steal( @@ -877,7 +890,8 @@ PYBIND11_MODULE(xla_extension, m) { ShapeIndex(param_index.begin(), param_index.end())); }); - m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor); + m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor, + py::arg("buffer"), py::arg("take_ownership") = true); m.def("dlpack_managed_tensor_to_buffer", DLPackManagedTensorToBuffer); py::enum_(m, "PrecisionConfig_Precision") diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 133483d2bb9..3de0ffcc2f8 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -90,11 +90,16 @@ def _gpu_backend_factory(distributed_client=None, node_id=0): node_id=node_id) +def _tpu_backend_factory(): + return _xla.get_tpu_client(asynchronous=True) + + # Backend factories, keyed by user-visible name, in increasing priority order. _local_backend_factories = collections.OrderedDict([ ('interpreter', _interpreter_backend_factory), ('cpu', _cpu_backend_factory), ('gpu', _gpu_backend_factory), + ('tpu', _tpu_backend_factory), ]) diff --git a/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py b/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py new file mode 100644 index 00000000000..180bb040cc4 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_client_backend_independent_test.py @@ -0,0 +1,147 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Backend-independent tests for the Python XLA client.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +from absl.testing import absltest +import numpy as np + +from tensorflow.compiler.xla.python import xla_client + +# pylint: disable=g-import-not-at-top +try: + import portpicker +except ImportError: + portpicker = None +# pylint: enable=g-import-not-at-top + +ops = xla_client.ops + + +class ShapeTest(absltest.TestCase): + + def testInvalidShapes(self): + with self.assertRaisesRegex(RuntimeError, + "shape's dimensions must not be < 0.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4]) + + with self.assertRaisesRegex( + RuntimeError, "layout minor_to_major field contains 1 element.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3]) + + with self.assertRaisesRegex( + RuntimeError, "layout minor_to_major field has out-of-bounds value.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], + [1, -1]) + + +class ComputationPrinting(absltest.TestCase): + + def ExampleComputation(self): + builder = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + x = ops.Mul(p0, p1) + ops.Add(x, x) + return builder.build() + + def testComputationToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_text() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testComputationToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = computation.as_hlo_dot_graph() + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + def testHloModuleToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_module().to_string() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testHloModuleToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( + computation.as_hlo_module()) + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + +class ComputationHashTest(absltest.TestCase): + + def testHash(self): + builder0 = xla_client.XlaBuilder("computation0") + p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder0, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation0 = builder0.build() + + builder1 = xla_client.XlaBuilder("computation1") + p0 = ops.Parameter(builder1, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder1, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation1 = builder1.build() + + self.assertEqual(computation0.hash(), computation1.hash()) + + +class AliasTest(absltest.TestCase): + + def testSetUpAlias(self): + c = xla_client.XlaBuilder(self.id()) + p1 = ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 1.0, np.float32)).with_major_to_minor_layout_if_absent()) + p2 = ops.Parameter( + c, 1, + xla_client.shape_from_pyval(np.array( + 1.0, np.float32)).with_major_to_minor_layout_if_absent()) + out = ops.Add(p1, p2) + c.setup_alias([], 0, []) + c.build(out) + + +class ProfilerTest(absltest.TestCase): + + def testTraceMe(self): + # TODO(phawkins): These tests just check that the TraceMe context manager + # acts like a context manager and doesn't explode. Ideally we'd check that + # the profiler saw the traceme too. + with xla_client.profiler.TraceMe("test1"): + pass + with xla_client.profiler.TraceMe("test2", foo=123): + pass + with self.assertRaises(ValueError): + with xla_client.profiler.TraceMe("test3"): + raise ValueError("test") + + @unittest.skipIf(portpicker is None, "Test requires portpicker") + def testStartServer(self): + port = portpicker.pick_unused_port() + server = xla_client.profiler.start_server(port) + del server + + +if __name__ == "__main__": + absltest.main() diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 3863d8a1481..0eaa7dabc61 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -1,4 +1,3 @@ -# Lint as: python3 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the Python extension-based XLA client.""" +"""Backend-dependent tests for the Python XLA client.""" from __future__ import absolute_import from __future__ import division @@ -37,12 +36,6 @@ try: except ImportError: custom_call_for_test = None -try: - import portpicker -except ImportError: - portpicker = None -# pylint: enable=g-import-not-at-top - bfloat16 = xla_client.bfloat16 ops = xla_client.ops @@ -142,27 +135,6 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Add(x, x) return builder.build() - def testComputationToHloText(self): - computation = self.ExampleComputation() - hlo_text = computation.as_hlo_text() - self.assertTrue(hlo_text.startswith("HloModule acomputation")) - - def testComputationToHloGraph(self): - computation = self.ExampleComputation() - hlo_dot_graph = computation.as_hlo_dot_graph() - self.assertTrue(hlo_dot_graph.startswith("digraph ")) - - def testHloModuleToHloText(self): - computation = self.ExampleComputation() - hlo_text = computation.as_hlo_module().to_string() - self.assertTrue(hlo_text.startswith("HloModule acomputation")) - - def testHloModuleToHloGraph(self): - computation = self.ExampleComputation() - hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( - computation.as_hlo_module()) - self.assertTrue(hlo_dot_graph.startswith("digraph ")) - @unittest.skipIf(cloud_tpu, "not implemented") def testCompiledHloModuleToHloText(self): computation = self.ExampleComputation() @@ -182,29 +154,6 @@ def TestFactory(xla_backend, cloud_tpu=False): tests.append(ComputationPrinting) - class ComputationHashTest(absltest.TestCase): - - def testHash(self): - builder0 = xla_client.XlaBuilder("computation0") - p0 = ops.Parameter(builder0, 0, - xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter( - builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) - ops.Mul(p0, p1) - computation0 = builder0.build() - - builder1 = xla_client.XlaBuilder("computation1") - p0 = ops.Parameter(builder1, 0, - xla_client.shape_from_pyval(np.float32(0))) - p1 = ops.Parameter( - builder1, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) - ops.Mul(p0, p1) - computation1 = builder1.build() - - self.assertEqual(computation0.hash(), computation1.hash()) - - tests.append(ComputationHashTest) - class ComputationsWithConstantsTest(ComputationTest): """Tests focusing on Constant ops.""" @@ -556,6 +505,7 @@ def TestFactory(xla_backend, cloud_tpu=False): self._ExecuteAndCompareExact( c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)]) + # pyformat: disable @parameterized.named_parameters({ "testcase_name": "_{}_{}".format(src_dtype.__name__, dst_dtype.__name__), @@ -563,6 +513,7 @@ def TestFactory(xla_backend, cloud_tpu=False): "dst_dtype": dst_dtype, } for src_dtype, dst_dtype in itertools.permutations( [np.bool, np.int32, np.int64, np.float32, np.float64], 2)) + # pyformat: enable def testConvertElementType(self, src_dtype, dst_dtype): if ((src_dtype in [np.int64, np.float64] or dst_dtype in [np.int64, np.float64]) and @@ -582,6 +533,7 @@ def TestFactory(xla_backend, cloud_tpu=False): self.assertEqual(result[0].dtype, expected.dtype) np.testing.assert_equal(result[0], expected) + # pyformat: disable @parameterized.named_parameters( { "testcase_name": "_{}_{}".format(src_dtype.__name__, @@ -591,6 +543,7 @@ def TestFactory(xla_backend, cloud_tpu=False): } for dtypes in [[np.int32, np.float32], [np.int64, np.float64]] for src_dtype, dst_dtype in itertools.permutations(dtypes, 2)) + # pyformat: enable def testBitcastConvertType(self, src_dtype, dst_dtype): if (np.float64 in (src_dtype, dst_dtype) and self.backend.platform == "tpu"): @@ -1901,24 +1854,6 @@ def TestFactory(xla_backend, cloud_tpu=False): tests.append(SetShardingTest) - class AliasTest(ComputationTest): - - def testSetUpAlias(self): - c = self._NewComputation() - p1 = ops.Parameter( - c, 0, - xla_client.shape_from_pyval( - NumpyArrayF32(1.0)).with_major_to_minor_layout_if_absent()) - p2 = ops.Parameter( - c, 1, - xla_client.shape_from_pyval( - NumpyArrayF32(1.0)).with_major_to_minor_layout_if_absent()) - out = ops.Add(p1, p2) - c.setup_alias([], 0, []) - c = c.build(out) - - tests.append(AliasTest) - testcase_shapes = [ (), (1,), @@ -1944,35 +1879,67 @@ def TestFactory(xla_backend, cloud_tpu=False): self.skipTest("DLPack requires CPU or GPU") # pylint: disable=g-complex-comprehension + # pyformat: disable @parameterized.named_parameters({ - "testcase_name": FormatShapeAndDtype(shape, dtype), + "testcase_name": "{}_own={}".format(FormatShapeAndDtype(shape, dtype), + take_ownership), "dtype": dtype, - "shape": shape - } for dtype in dlpack_dtypes for shape in testcase_shapes) - def testRoundTrip(self, dtype, shape): + "shape": shape, + "take_ownership": take_ownership + } for dtype in dlpack_dtypes for shape in testcase_shapes + for take_ownership in [False, True]) + # pyformat: enable + def testRoundTrip(self, dtype, shape, take_ownership): x = np.array(np.random.rand(*shape) * 100, dtype=dtype) buffer = self.backend.buffer_from_pyval(x) - dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor( + buffer, take_ownership=take_ownership) del buffer # Free "buffer" to make sure dlt retains ownership. self.assertEqual(type(dlt).__name__, "PyCapsule") - y = xla_client._xla.dlpack_managed_tensor_to_buffer( - dlt, self.backend) + y = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend) np.testing.assert_array_equal(x, y.to_py()) def testTensorsCanBeConsumedOnceOnly(self): x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) buffer = self.backend.buffer_from_pyval(x) - dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor( + buffer, take_ownership=True) def ConsumeDLPackTensor(): - _ = xla_client._xla.dlpack_managed_tensor_to_buffer( - dlt, self.backend) + _ = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend) ConsumeDLPackTensor() self.assertRaisesRegex( RuntimeError, ".*a DLPack tensor may be consumed at most once.*", ConsumeDLPackTensor) + def testTensorsCanBeOwnedOnceOnly(self): + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + buffer = self.backend.buffer_from_pyval(x) + _ = xla_client._xla.buffer_to_dlpack_managed_tensor( + buffer, take_ownership=True) + self.assertTrue(buffer.is_deleted()) + with self.assertRaisesRegex( + RuntimeError, + "Cannot convert deleted/invalid buffer to DLPack tensor.*"): + _ = xla_client._xla.buffer_to_dlpack_managed_tensor( + buffer, take_ownership=True) + + def testNonOwnedDlpackCanBeViewedTwice(self): + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + buffer = self.backend.buffer_from_pyval(x) + d1 = xla_client._xla.buffer_to_dlpack_managed_tensor( + buffer, take_ownership=False) + d2 = xla_client._xla.buffer_to_dlpack_managed_tensor( + buffer, take_ownership=False) + + y = xla_client._xla.dlpack_managed_tensor_to_buffer(d1, self.backend) + z = xla_client._xla.dlpack_managed_tensor_to_buffer(d2, self.backend) + del d1, d2 + np.testing.assert_array_equal(x, buffer.to_py()) + np.testing.assert_array_equal(x, y.to_py()) + np.testing.assert_array_equal(x, z.to_py()) + tests.append(DLPackTest) class BufferProtocolTest(parameterized.TestCase): @@ -2022,28 +1989,6 @@ def TestFactory(xla_backend, cloud_tpu=False): tests.append(BufferProtocolTest) - class ProfilerTest(absltest.TestCase): - - def testTraceMe(self): - # TODO(phawkins): These tests just check that the TraceMe context manager - # acts like a context manager and doesn't explode. Ideally we'd check that - # the profiler saw the traceme too. - with xla_client.profiler.TraceMe("test1"): - pass - with xla_client.profiler.TraceMe("test2", foo=123): - pass - with self.assertRaises(ValueError): - with xla_client.profiler.TraceMe("test3"): - raise ValueError("test") - - @unittest.skipIf(portpicker is None, "Test requires portpicker") - def testStartServer(self): - port = portpicker.pick_unused_port() - server = xla_client.profiler.start_server(port) - del server - - tests.append(ProfilerTest) - class TracebackTest(absltest.TestCase): def setUp(self): diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 491d1d67877..e16575bebd4 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -241,7 +241,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 80c5513bb9d..e5c59fc0c7a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -45,8 +45,9 @@ cc_library( ) filegroup( - name = "single_threaded_runtime_srcs", + name = "runtime_srcs", srcs = [ + # Single-threaded support. "runtime_fp16.cc", "runtime_key_value_sort.cc", "runtime_pow.cc", @@ -54,13 +55,20 @@ filegroup( "runtime_single_threaded_fft.cc", "runtime_single_threaded_matmul.cc", "runtime_topk.cc", + ] + [ + # Multi-threaded support. + "runtime_conv2d.cc", + "runtime_fft.cc", + "runtime_matmul.cc", + "runtime_fork_join.cc", ], visibility = [":friends"], ) filegroup( - name = "single_threaded_runtime_hdrs", + name = "runtime_hdrs", srcs = [ + # Single-threaded support. "runtime_conv2d_impl.h", "runtime_fft_impl.h", "runtime_fp16.h", @@ -70,6 +78,13 @@ filegroup( "runtime_single_threaded_fft.h", "runtime_single_threaded_matmul.h", "runtime_topk.h", + ] + [ + # Multi-threaded support. + "runtime_conv2d.h", + "runtime_fft.h", + "runtime_fork_join.h", + "runtime_lightweight_check.h", + "runtime_matmul.h", ], visibility = [":friends"], ) @@ -487,7 +502,6 @@ cc_library( ":cpu_runtime", ":ir_emission_utils", ":mlir_emitter", - ":mlir_matmul_codegen_strategy", ":target_machine_features", ":tiled_dot_emitter", ":vector_support_library", @@ -509,6 +523,7 @@ cc_library( "@llvm-project//mlir:EDSC", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:StandardOps", ], ) @@ -552,6 +567,7 @@ cc_library( "@llvm-project//llvm:IPO", "@llvm-project//llvm:MC", "@llvm-project//llvm:Object", + "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", ], @@ -1143,24 +1159,3 @@ cc_library( "@llvm-project//mlir:VectorToLLVM", ], ) - -cc_library( - name = "mlir_matmul_codegen_strategy", - srcs = ["mlir_matmul_codegen_strategy.cc"], - hdrs = ["mlir_matmul_codegen_strategy.h"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Affine", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgOps", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorOps", - "@llvm-project//mlir:VectorToSCF", - ], -) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index a21ace0d8b2..643de6c4e58 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -80,8 +80,8 @@ class FilteredPassManager : public llvm::legacy::PassManager { }; } // anonymous namespace -std::unique_ptr CompilerFunctor::operator()( - llvm::Module& module) const { +llvm::Expected> CompilerFunctor::operator()( + llvm::Module& module) { FilteredPassManager module_passes(disable_expensive_passes_); llvm::legacy::FunctionPassManager function_passes(&module); @@ -155,7 +155,7 @@ std::unique_ptr CompilerFunctor::operator()( } } - return memory_buffer; + return std::move(memory_buffer); } static std::vector VectorFunctionsForTargetLibraryInfoImpl() { diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index 647f0d18ef5..6211588861b 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_ +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" @@ -29,7 +30,7 @@ namespace cpu { // Functor class for compiling an LLVM module down to an object file. For use by // Orc JIT compile layer. -class CompilerFunctor { +class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler { public: explicit CompilerFunctor( llvm::TargetMachine* target_machine, int opt_level, @@ -39,7 +40,8 @@ class CompilerFunctor { LLVMCompiler::ModuleHook post_optimization_hook = nullptr, std::function post_codegen_hook = nullptr) - : target_machine_(target_machine), + : IRCompiler(llvm::orc::IRSymbolMapper::ManglingOptions()), + target_machine_(target_machine), opt_level_(opt_level), optimize_for_size_(optimize_for_size), disable_expensive_passes_(disable_expensive_passes), @@ -49,8 +51,8 @@ class CompilerFunctor { post_codegen_hook_(std::move(post_codegen_hook)) {} // Compile a Module to an ObjectFile. - std::unique_ptr operator()( - llvm::Module& module) const; // NOLINT + llvm::Expected> operator()( + llvm::Module& module) override; private: // Populates the given pass manager with TargetLibraryInfo and diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 1ffafd37a27..e92f890ba67 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -38,6 +39,7 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Object/ObjectFile.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Error.h" #include "llvm/Support/TargetRegistry.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" @@ -645,11 +647,11 @@ StatusOr> CpuCompiler::RunBackend( // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; LoadMLIRDialects(mlir_context); - llvm::LLVMContext llvm_context; + auto llvm_context = std::make_unique(); auto llvm_module = - absl::make_unique("__compute_module", llvm_context); + absl::make_unique("__compute_module", *llvm_context); - auto jit = absl::make_unique( + auto jit = SimpleOrcJIT::Create( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), @@ -657,8 +659,12 @@ StatusOr> CpuCompiler::RunBackend( llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook, post_optimization_ir_hook, OrcJITPostCompilationHook::Create(module.get())); - llvm_module->setDataLayout(jit->data_layout()); - llvm_module->setTargetTriple(jit->target_triple().getTriple()); + if (!jit) { + return InternalError("Creating JIT failed: %s", + llvm::toString(jit.takeError())); + } + llvm_module->setDataLayout((*jit)->data_layout()); + llvm_module->setTargetTriple((*jit)->target_triple().getTriple()); HloComputation* entry_computation = module->entry_computation(); std::unordered_map instruction_to_profile_idx; @@ -700,7 +706,7 @@ StatusOr> CpuCompiler::RunBackend( // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. - LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); + LLVMTargetMachineFeatures target_machine_features((*jit)->target_machine()); IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), @@ -739,7 +745,7 @@ StatusOr> CpuCompiler::RunBackend( string function_name = [&]() { llvm::SmallVector function_name_vector; llvm::Mangler::getNameWithPrefix( - function_name_vector, entry_function->getName(), jit->data_layout()); + function_name_vector, entry_function->getName(), (*jit)->data_layout()); return string(function_name_vector.begin(), function_name_vector.end()); }(); @@ -751,9 +757,11 @@ StatusOr> CpuCompiler::RunBackend( TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. - jit->AddModule(std::move(llvm_module)); + llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module), + std::move(llvm_context)); + cantFail((*jit)->AddModule(std::move(thread_safe_module))); cpu_executable.reset(new CpuExecutable( - std::move(jit), std::move(assignment), std::move(module), function_name, + std::move(*jit), std::move(assignment), std::move(module), function_name, std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { @@ -971,7 +979,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook); std::unique_ptr object_file = - compiler_functor(llvm_module); + cantFail(compiler_functor(llvm_module)); ObjectFileData object_file_data(object_file->getBufferStart(), object_file->getBufferEnd()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 02bc445ce9a..456bb8c5a32 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -63,14 +63,14 @@ CpuExecutable::CpuExecutable( assignment_(std::move(assignment)) { // Resolve symbols in the constructor rather than at execution time to avoid // races because FindSymbol is not thread safe. - llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry_function_name); + llvm::Expected sym = + jit_->FindCompiledSymbol(entry_function_name); // We expect to find the symbol provided with entry_function_name; otherwise // this is an internal error. - CHECK(sym) << "Symbol " << entry_function_name << " not found."; + CHECK(*sym) << "Symbol " << entry_function_name << " not found."; // getAddress can do work under the hood in the jit, so it needs to be // guarded by the mutex. - compute_function_ = - reinterpret_cast(cantFail(sym.getAddress())); + compute_function_ = reinterpret_cast(sym->getAddress()); VLOG(1) << "compute_function_ at address " << reinterpret_cast(compute_function_); } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 63d44af4a9e..ba8b74a64a5 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" // from @llvm-project #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" // from @llvm-project #include "mlir/EDSC/Builders.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -36,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" -#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" @@ -322,7 +322,7 @@ Status DotOpEmitter::EmitLinalgMatmul() { int64 alignment = target_machine_features_.minimum_alignment_for_allocation( ShapeUtil::ByteSizeOf(dot_info_.result_shape)); - mlir_strategy::MatmulCodegenStrategy strategy; + mlir::linalg::CodegenStrategy strategy; strategy.tile(tilingOptions) .promote( mlir::linalg::LinalgPromotionOptions() diff --git a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc b/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc deleted file mode 100644 index ea89071a967..00000000000 --- a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "mlir/Analysis/SliceAnalysis.h" // from @llvm-project -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Utils/Utils.h" // from @llvm-project -#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project -#include "mlir/Dialect/SCF/Utils.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project -#include "mlir/IR/AffineExpr.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project -#include "mlir/IR/Dominance.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/LoopUtils.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project - -// TODO(kramerb): Remove this once strategy is in mlir core. - -using namespace mlir; // NOLINT -using namespace mlir::linalg; // NOLINT - -#define DEBUG_TYPE "matmul-codegen-strategy" - -namespace xla { -namespace cpu { -namespace mlir_strategy { - -//===----------------------------------------------------------------------===// -// TODO: Cleanup and upstream these to go into core. Please ignore for now ! -//===----------------------------------------------------------------------===// -static void hoistRedundantCopies(FuncOp func) { - bool changed = true; - while (changed) { - changed = false; - func.walk([&](linalg::FillOp op) { - auto loop = op.getParentOfType(); - if (!loop) return; - - for (auto operand : op.getOperands()) - if (!loop.isDefinedOutsideOfLoop(operand)) return; - - // Hoist fill before. - op.getOperation()->moveBefore(loop); - changed = true; - }); - - func.walk([&](linalg::CopyOp op) { - auto loop = op.getParentOfType(); - if (!loop) return; - - for (auto operand : op.getOperands()) - if (!loop.isDefinedOutsideOfLoop(operand)) return; - - Value sourceView = op.getInput(0); - while (auto subViewOp = sourceView.getDefiningOp()) - sourceView = subViewOp.getViewSource(); - - // Source traces back to a block argument. - if (sourceView.isa()) { - op.getOperation()->moveBefore(loop); - } else { - assert(sourceView.getDefiningOp() || - sourceView.getDefiningOp() || - sourceView.getDefiningOp()); - op.getOperation()->moveAfter(loop); - } - changed = true; - }); - } -} - -/// Substitute scf.for = %lb to %ub step %step by an AffineExpr expressing: -/// `%lb + %step * new_dim` where -/// 1. the AffineExpr for %lb is either an AffineConstantExpr or an -/// AffineDimExpr depending on whether the value is constant or not. -/// 2. the AffineExpr for %step is either an AffineConstantExpr or an -/// AffineSymbolExpr depending on whether the value is constant or not. -/// -static void substitute(scf::ForOp forOp, SmallVectorImpl &exprs, - SmallVectorImpl &dims, - SmallVectorImpl &symbols) { - MLIRContext *ctx = forOp.getContext(); - auto lbConstant = forOp.lowerBound().getDefiningOp(); - AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx) - : getAffineDimExpr(dims.size(), ctx); - - auto stepConstant = forOp.step().getDefiningOp(); - AffineExpr step = stepConstant - ? getAffineConstantExpr(stepConstant.getValue(), ctx) - : getAffineSymbolExpr(symbols.size(), ctx); - - if (!lbConstant) dims.push_back(forOp.lowerBound()); - if (!stepConstant) symbols.push_back(forOp.step()); - exprs.push_back(lb + step * getAffineDimExpr(dims.size(), ctx)); - - auto ubConstant = forOp.upperBound().getDefiningOp(); - AffineExpr ub = ubConstant ? getAffineConstantExpr(ubConstant.getValue(), ctx) - : getAffineDimExpr(dims.size(), ctx); - if (!ubConstant) dims.push_back(forOp.upperBound()); - exprs.push_back(ub); - - dims.push_back(forOp.getInductionVar()); -} - -/// Traverse the . -static void substitute(AffineMinOp minOp, SmallVectorImpl &exprs, - SmallVectorImpl &dims, - SmallVectorImpl &symbols) { - MLIRContext *ctx = minOp.getContext(); - for (Value v : minOp.getDimOperands()) { - if (auto forOp = scf::getForInductionVarOwner(v)) { - substitute(forOp, exprs, dims, symbols); - continue; - } - if (auto parentMinOp = v.getDefiningOp()) { - substitute(parentMinOp, exprs, dims, symbols); - continue; - } - exprs.push_back(getAffineDimExpr(dims.size(), ctx)); - dims.push_back(v); - } -} - -/// Perform folding of chains of AffineMinOp. -struct AffineMinCanonicalizationPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(AffineMinOp minOp, - PatternRewriter &rewriter) const override; -}; - -LogicalResult AffineMinCanonicalizationPattern::matchAndRewrite( - AffineMinOp minOp, PatternRewriter &rewriter) const { - LLVM_DEBUG(llvm::dbgs() << "\nCanonicalize AffineMin: " - << *minOp.getOperation() << "\n"); - - int64_t min = std::numeric_limits::max(); - for (auto e : minOp.map().getResults()) - if (auto cstExpr = e.dyn_cast()) - min = std::min(min, cstExpr.getValue()); - if (min == std::numeric_limits::max()) return failure(); - - SmallVector exprs; - SmallVector dims, symbols; - substitute(minOp, exprs, dims, symbols); - - SmallVector operands = dims; - operands.append(symbols.begin(), symbols.end()); - - MLIRContext *ctx = minOp.getContext(); - auto map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx); - LLVM_DEBUG(llvm::dbgs() << "Substitution map: " << map << "\n"); - - SmallVector modExprs; - for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) - modExprs.push_back(getAffineDimExpr(idx, ctx) % min); - map = AffineMap::get(map.getNumResults(), 0, modExprs, ctx).compose(map); - canonicalizeMapAndOperands(&map, &operands); - map = simplifyAffineMap(map); - - LLVM_DEBUG(llvm::dbgs() << "Post mod: " << map << "\n"; - llvm::interleaveComma(operands, llvm::dbgs())); - - if (!llvm::all_of(map.getResults(), [](AffineExpr e) { - if (auto cst = e.dyn_cast()) - return cst.getValue() == 0; - return false; - })) - return failure(); - - rewriter.replaceOpWithNewOp(minOp, min); - return success(); -} -//===----------------------------------------------------------------------===// -// END TODO -//===----------------------------------------------------------------------===// - -void MatmulCodegenStrategy::transform(FuncOp func) const { - MLIRContext *context = func.getContext(); - // Emplace patterns one at a time while also maintaining a simple chained - // state transition. - unsigned stepCount = 0; - SmallVector stage1Patterns; - auto zeroState = Identifier::get(std::to_string(stepCount), context); - auto currentState = zeroState; - for (auto &t : transformation_sequence) { - auto nextState = Identifier::get(std::to_string(++stepCount), context); - auto marker = (currentState == zeroState) - ? linalg::LinalgMarker({}, nextState) - : linalg::LinalgMarker(currentState, nextState); - stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker)); - currentState = nextState; - } - - OwningRewritePatternList stage2Patterns = - linalg::getLinalgTilingCanonicalizationPatterns(context); - stage2Patterns.insert(context); - - auto stage3Transforms = [](Operation *op) { - // Some of these may be too aggressive as a stage 3 that is applied on each - // stage 1 application and may have to be split out to post staged patterns - // application (in which case they could just be passes, TBD). - PassManager pm(op->getContext()); - pm.addPass(createLoopInvariantCodeMotionPass()); - if (failed(pm.run(op->getParentOfType()))) - llvm_unreachable("Unexpected failure in cleanup pass pipeline."); - promoteSingleIterationLoops(cast(op)); - hoistViewAllocOps(cast(op)); - hoistRedundantVectorTransfers(cast(op)); - hoistRedundantCopies(cast(op)); - return success(); - }; - linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns, - stage3Transforms); - - //===--------------------------------------------------------------------===// - // Post staged patterns transforms - //===--------------------------------------------------------------------===// - // Programmatic controlled lowering of vector.contract only. - OwningRewritePatternList vectorContractLoweringPatterns; - vectorContractLoweringPatterns - .insert( - vector_transforms_options, context); - applyPatternsAndFoldGreedily(func, vectorContractLoweringPatterns); - - // Programmatic controlled lowering of vector.transfer only. - OwningRewritePatternList vectorToLoopsPatterns; - populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, - vector_to_scf_options); - applyPatternsAndFoldGreedily(func, vectorToLoopsPatterns); -} - -} // namespace mlir_strategy -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h b/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h deleted file mode 100644 index 3b11b750c47..00000000000 --- a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ -#define MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSwitch.h" -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project - -// TODO(kramerb): Remove this once strategy is in mlir core. - -namespace xla { -namespace cpu { -namespace mlir_strategy { - -/// Abstract Transformation class applied in a sequence that also handles state -/// through markers. -struct Transformation { - virtual ~Transformation() = default; - virtual mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) = 0; - mlir::linalg::LinalgMarker marker; -}; - -/// Promotion transformation enqueues a particular stage-1 pattern for -/// `Tile`with the appropriate `options`. -// TODO: variadic LinalgOpTypes. -template -struct Tile : public Transformation { - explicit Tile(mlir::linalg::LinalgTilingOptions options) : options(options) {} - - mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { - mlir::OwningRewritePatternList tiling_patterns; - tiling_patterns.insert>( - context, options, m); - return tiling_patterns; - } - - private: - mlir::linalg::LinalgTilingOptions options; -}; - -/// Promotion transformation enqueues a particular stage-1 pattern for -/// `Promote`with the appropriate `options`. -// TODO: variadic LinalgOpTypes. -template -struct Promote : public Transformation { - explicit Promote(mlir::linalg::LinalgPromotionOptions options) - : options(options) {} - - mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { - mlir::OwningRewritePatternList promotion_patterns; - promotion_patterns - .insert>(context, - options, m); - return promotion_patterns; - } - - private: - mlir::linalg::LinalgPromotionOptions options; -}; - -/// Vectorization transformation enqueues a particular stage-1 pattern for -/// `LinalgVectorizationPattern` as well as copy to vector -/// transfer rewrite forwarding patterns. -// TODO: variadic LinalgOpTypes. -template -struct Vectorize : public Transformation { - mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { - mlir::OwningRewritePatternList vectorization_patterns; - // FillOp may interfere with forwarding patterns atm, so we bump up the - // priority of LinalgCopyVTRForwardingPattern / - // LinalgCopyVTWForwardingPattern. - vectorization_patterns - .insert>(context, - m); - vectorization_patterns.insert( - context, - /*benefit=*/2); - return vectorization_patterns; - } -}; - -/// Matmul-specific strategy object controls how a linalg.matmul is -/// progressively lowered. -/// The strategy uses a 3-level staged patterns strategy which allows ordering -/// transformations by using the Linalg `applyStagedPatterns` function, where: -/// 1. The first stage consists of the successive `tile`, `promote` and -/// `vectorize` patterns, applied sequentially. -/// 2. The second stage consists of common local canonicalization patterns -/// that are applied eagerly after each stage-1 pattern. -/// 3. the third stage consists of more global transformation, also applied -/// eagerly, after all stage-2 patterns. Such more global transformations -struct MatmulCodegenStrategy { - /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling - /// `options`. - template - MatmulCodegenStrategy &tile(mlir::linalg::LinalgTilingOptions options) { - transformation_sequence.emplace_back(new Tile(options)); - return *this; - } - /// Conditionally append a pattern to add a level of tiling for `LinalgOpType` - /// with tiling `options`. - template - MatmulCodegenStrategy &tileIf(bool b, - mlir::linalg::LinalgTilingOptions options) { - return b ? tile(options) : *this; - } - /// Append a pattern to add a level of promotion for `LinalgOpType` with - /// promotion `options`. - template - MatmulCodegenStrategy &promote(mlir::linalg::LinalgPromotionOptions options) { - transformation_sequence.emplace_back(new Promote(options)); - return *this; - } - /// Conditionally append a pattern to add a level of promotion for - /// `LinalgOpType` with promotion `options`. - template - MatmulCodegenStrategy &promoteIf( - bool b, mlir::linalg::LinalgPromotionOptions options) { - return b ? promote(options) : *this; - return *this; - } - /// Append a pattern to rewrite `LinalgOpType` as a vector operation. - template - MatmulCodegenStrategy &vectorize() { - transformation_sequence.emplace_back(new Vectorize()); - return *this; - } - /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector - /// operation. - template - MatmulCodegenStrategy &vectorizeIf(bool b) { - return b ? vectorize() : *this; - return *this; - } - /// Configure the post staged-patterns late vector transformations. - MatmulCodegenStrategy &setVectorTransformsOptions( - mlir::vector::VectorTransformsOptions options) { - vector_transforms_options = options; - return *this; - } - /// Configure the post staged-patterns late vector.transfer to scf conversion. - MatmulCodegenStrategy &setVectorTransferToSCFOptions( - mlir::VectorTransferToSCFOptions options) { - vector_to_scf_options = options; - return *this; - } - - /// Apply the transformation patterns in sequence with cleanup transformations - /// interleaved. - void transform(mlir::FuncOp func) const; - - private: - mlir::LogicalResult postPatternTransforms(mlir::Operation *func) const; - - mlir::vector::VectorTransformsOptions vector_transforms_options; - mlir::VectorTransferToSCFOptions vector_to_scf_options; - llvm::SmallVector, 4> transformation_sequence; -}; - -} // namespace mlir_strategy -} // namespace cpu -} // namespace xla - -#endif // MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 28508bde4cd..41f24af6652 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -85,6 +85,8 @@ SimpleOrcJIT::InferTargetMachineForJIT( } SimpleOrcJIT::SimpleOrcJIT( + std::unique_ptr target_process_control, + std::unique_ptr execution_session, const llvm::TargetOptions& target_options, llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags, @@ -93,48 +95,89 @@ SimpleOrcJIT::SimpleOrcJIT( std::function post_codegen_hook) : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), data_layout_(target_machine_->createDataLayout()), - symbol_resolver_(llvm::orc::createLegacyLookupResolver( - execution_session_, - [this](llvm::StringRef name) -> llvm::JITSymbol { - return this->ResolveRuntimeSymbol(std::string(name)); - }, - [](llvm::Error Err) { - cantFail(std::move(Err), "lookupFlags failed"); - })), - object_layer_( - execution_session_, - [this](llvm::orc::VModuleKey) { - llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources result; - result.MemMgr = std::make_shared( - orc_jit_memory_mapper::GetInstance()); - result.Resolver = symbol_resolver_; - return result; - }, - /*NotifyLoaded=*/ - llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(), - /*NotifyFinalized=*/ - [this](VModuleKeyT, const llvm::object::ObjectFile& object, - const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { - this->NotifyObjectFinalized(object, object_info); - }, - /*NotifyFreed=*/ - [this](VModuleKeyT, const llvm::object::ObjectFile& object) { - this->NotifyObjectFreed(object); - }), + target_process_control_(std::move(target_process_control)), + execution_session_(std::move(execution_session)), + object_layer_(*execution_session_, + []() { + return std::make_unique( + orc_jit_memory_mapper::GetInstance()); + }), compile_layer_( - object_layer_, - CompilerFunctor(target_machine_.get(), opt_level, optimize_for_size, - disable_expensive_passes, fast_math_flags, - std::move(pre_optimization_hook), - std::move(post_optimization_hook), - std::move(post_codegen_hook))), + *execution_session_, object_layer_, + std::make_unique( + target_machine_.get(), opt_level, optimize_for_size, + disable_expensive_passes, fast_math_flags, + std::move(pre_optimization_hook), + std::move(post_optimization_hook), std::move(post_codegen_hook))), + main_jit_dylib_(&execution_session_->createBareJITDylib("
")), gdb_jit_event_listener_( llvm::JITEventListener::createGDBRegistrationListener()) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() << " features: " << target_machine_->getTargetFeatureString().str(); + + // Materialize unknown symbols from the runtime symbol table. + class RuntimeSymbolGenerator : public llvm::orc::DefinitionGenerator { + SimpleOrcJIT& jit_; + + public: + explicit RuntimeSymbolGenerator(SimpleOrcJIT& jit) : jit_(jit) {} + llvm::Error tryToGenerate( + llvm::orc::LookupState&, llvm::orc::LookupKind, + llvm::orc::JITDylib& jit_dylib, llvm::orc::JITDylibLookupFlags, + const llvm::orc::SymbolLookupSet& names) override { + llvm::orc::SymbolMap new_defs; + + for (const auto& kv : names) { + const auto& name = kv.first; + if (llvm::JITEvaluatedSymbol symbol = + jit_.ResolveRuntimeSymbol(*name)) { + new_defs[name] = symbol; + } + } + + cantFail(jit_dylib.define(absoluteSymbols(std::move(new_defs)))); + return llvm::Error::success(); + } + }; + main_jit_dylib_->addGenerator( + std::make_unique(*this)); + object_layer_.registerJITEventListener(*this); + + // Copied from LLJIT, required to find symbols on Windows. + if (target_machine_->getTargetTriple().isOSBinFormatCOFF()) { + object_layer_.setOverrideObjectFlagsWithResponsibilityFlags(true); + object_layer_.setAutoClaimResponsibilityForObjectSymbols(true); + } } -llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { +SimpleOrcJIT::~SimpleOrcJIT() { + if (auto err = execution_session_->endSession()) { + execution_session_->reportError(std::move(err)); + } +} + +llvm::Expected> SimpleOrcJIT::Create( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, + bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags, + LLVMCompiler::ModuleHook pre_optimization_hook, + LLVMCompiler::ModuleHook post_optimization_hook, + std::function post_codegen_hook) { + auto target_process_control = llvm::orc::SelfTargetProcessControl::Create(); + if (!target_process_control) { + return target_process_control.takeError(); + } + + auto execution_session = std::make_unique(); + return std::make_unique( + std::move(*target_process_control), std::move(execution_session), + target_options, opt_level, optimize_for_size, disable_expensive_passes, + fast_math_flags, std::move(pre_optimization_hook), + std::move(post_optimization_hook), std::move(post_codegen_hook)); +} + +llvm::JITEvaluatedSymbol SimpleOrcJIT::ResolveRuntimeSymbol( + llvm::StringRef name) { void* func_addr = nullptr; if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) { // On Mac OS X, 'name' may have a leading underscore prefix, even though the @@ -143,12 +186,13 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host"); } else { - func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host"); + func_addr = + xla::CustomCallTargetRegistry::Global()->Lookup(name.str(), "Host"); } if (func_addr == nullptr) { LOG(ERROR) - << "Unable to resolve runtime symbol: `" << name + << "Unable to resolve runtime symbol: `" << name.str() << "'. Hint: if the symbol a custom call target, make sure you've " "registered it with the JIT using " "XLA_CPU_REGISTER_CUSTOM_CALL_TARGET."; @@ -159,60 +203,25 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { return symbol_info; } -void SimpleOrcJIT::NotifyObjectFinalized( +void SimpleOrcJIT::notifyObjectLoaded( + llvm::JITEventListener::ObjectKey key, const llvm::object::ObjectFile& object, const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { - uint64_t key = static_cast( - reinterpret_cast(object.getData().data())); gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info); size_of_generated_code_in_bytes_ += object.getData().size(); } -void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) { - uint64_t key = static_cast( - reinterpret_cast(object.getData().data())); +void SimpleOrcJIT::notifyFreeingObject(llvm::JITEventListener::ObjectKey key) { gdb_jit_event_listener_->notifyFreeingObject(key); } -SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule( - std::unique_ptr module) { - auto key = execution_session_.allocateVModule(); - cantFail(compile_layer_.addModule(key, std::move(module))); - module_keys_.push_back(key); - return key; +llvm::Error SimpleOrcJIT::AddModule(llvm::orc::ThreadSafeModule module) { + return compile_layer_.add(*main_jit_dylib_, std::move(module)); } -void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) { - module_keys_.erase(std::remove(module_keys_.begin(), module_keys_.end(), key), - module_keys_.end()); - cantFail(compile_layer_.removeModule(key)); -} - -llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) { -#ifdef _WIN32 - // The symbol lookup of ObjectLinkingLayer uses the SymbolRef::SF_Exported - // flag to decide whether a symbol will be visible or not, when we call - // IRCompileLayer::findSymbolIn with ExportedSymbolsOnly set to true. - // - // But for Windows COFF objects, this flag is currently never set. - // For a potential solution see: https://reviews.llvm.org/rL258665 - // For now, we allow non-exported symbols on Windows as a workaround. - const bool exported_symbols_only = false; -#else - const bool exported_symbols_only = true; -#endif - - // Resolve symbol from last module to first, allowing later redefinitions of - // symbols shadow earlier ones. - for (auto& key : - llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) { - if (auto symbol = - compile_layer_.findSymbolIn(key, name, exported_symbols_only)) { - return symbol; - } - } - - return nullptr; +llvm::Expected SimpleOrcJIT::FindCompiledSymbol( + const std::string& name) { + return execution_session_->lookup({main_jit_dylib_}, name); } #if defined(PLATFORM_WINDOWS) diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 9c470edbac2..36c32ed23e6 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -26,6 +26,7 @@ limitations under the License. #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "llvm/ExecutionEngine/Orc/SymbolStringPool.h" +#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h" #include "llvm/IR/Module.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" @@ -42,13 +43,10 @@ namespace cpu { // Supports JIT-ing multiple modules but without cross-module linking. // Implements eager compilation - the module is lowered to binary as soon as // it's added to the JIT. -class SimpleOrcJIT { +class SimpleOrcJIT : public llvm::JITEventListener { public: - using ObjLayerT = llvm::orc::LegacyRTDyldObjectLinkingLayer; - using CompileFtor = - std::function(llvm::Module&)>; - using CompileLayerT = llvm::orc::LegacyIRCompileLayer; - using VModuleKeyT = llvm::orc::VModuleKey; + using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer; + using CompileLayerT = llvm::orc::IRCompileLayer; // Create a new JIT, targeting the host architecture. // @@ -56,6 +54,8 @@ class SimpleOrcJIT { // LLVM IR-level optimizations. post_codegen_hook is invoked after // compiling to machine code. SimpleOrcJIT( + std::unique_ptr target_process_control, + std::unique_ptr execution_session, const llvm::TargetOptions& target_options, llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags, @@ -63,22 +63,28 @@ class SimpleOrcJIT { LLVMCompiler::ModuleHook post_optimization_hook, std::function post_codegen_hook); + static llvm::Expected> Create( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, + bool disable_expensive_passes, llvm::FastMathFlags fast_math_flags, + LLVMCompiler::ModuleHook pre_optimization_hook, + LLVMCompiler::ModuleHook post_optimization_hook, + std::function post_codegen_hook); + + ~SimpleOrcJIT() override; + const llvm::DataLayout& data_layout() const { return data_layout_; } const llvm::Triple& target_triple() const { return target_machine_->getTargetTriple(); } - // Add a module to the JIT. Returns an opaque key that can be used to later - // remove this module. - VModuleKeyT AddModule(std::unique_ptr module); - - // Remove a module from the JIT and free the memory associated with it. - void RemoveModule(VModuleKeyT key); + llvm::Error AddModule(llvm::orc::ThreadSafeModule module); // Get the runtime address of the compiled symbol whose name is given. Returns // nullptr if the symbol cannot be found. - llvm::JITSymbol FindCompiledSymbol(const std::string& name); + llvm::Expected FindCompiledSymbol( + const std::string& name); llvm::TargetMachine* target_machine() const { return target_machine_.get(); } @@ -93,20 +99,21 @@ class SimpleOrcJIT { } private: - llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); + llvm::JITEvaluatedSymbol ResolveRuntimeSymbol(llvm::StringRef name); - void NotifyObjectFinalized( + void notifyObjectLoaded( + llvm::JITEventListener::ObjectKey key, const llvm::object::ObjectFile& object, - const llvm::RuntimeDyld::LoadedObjectInfo& object_info); - void NotifyObjectFreed(const llvm::object::ObjectFile& object); + const llvm::RuntimeDyld::LoadedObjectInfo& object_info) override; + void notifyFreeingObject(llvm::JITEventListener::ObjectKey key) override; - std::vector module_keys_; std::unique_ptr target_machine_; const llvm::DataLayout data_layout_; - llvm::orc::ExecutionSession execution_session_; - std::shared_ptr symbol_resolver_; + std::unique_ptr target_process_control_; + std::unique_ptr execution_session_; ObjLayerT object_layer_; CompileLayerT compile_layer_; + llvm::orc::JITDylib* main_jit_dylib_; int64 size_of_generated_code_in_bytes_ = 0; // Non owning pointer to a JIT event listener that registers the JIT events diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index b4c56113239..e728cd75caf 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -39,12 +39,17 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/platform/errors.h" namespace xla { namespace { +auto* dynamic_padding_gauge = tensorflow::monitoring::Gauge::New( + "/tensorflow/core/use_dynamic_padding_gauge", + "Tracks if dynamic padder is used."); + // ChooseIdentityValue looks at the instruction's operand, returns a // identity value which, when padded, doesn't change the result of the // instruction. @@ -1351,6 +1356,7 @@ StatusOr DynamicPadder::Run(HloModule* module) { operand, input_dim, operand_dynamic_size, identity_value); TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded)); operand = inst->mutable_operand(operand_num); + dynamic_padding_gauge->GetCell()->Set(true); changed = true; } } @@ -1397,6 +1403,7 @@ StatusOr DynamicPadder::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(changed, dce.Run(module)); VLOG(2) << "Post DynamicPadder HLO:"; XLA_VLOG_LINES(2, module->ToString()); + dynamic_padding_gauge->GetCell()->Set(changed); return changed; } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 98d523487b4..3a449b7c2db 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -2540,6 +2540,10 @@ StatusOr ElementalIrEmitter::EmitElementalReduceWindow( // if I in bounds of input // value = function(value, input[I]) // output[O] = value + if (reduce_window->shape().IsTuple()) { + return Status(tensorflow::error::UNIMPLEMENTED, + "Variadic reduce window op is not yet fully supported."); + } const HloInstruction* operand = reduce_window->operand(0); const Window& window = reduce_window->window(); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b2ec656a2ba..d2c462e0957 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -288,6 +288,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", @@ -913,6 +914,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_reachability", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1186,6 +1188,7 @@ cc_library( ":gpu_layout_assignment", ":gpu_sanitize_constant_names", ":gpu_scatter_expander", + ":horizontal_input_fusion", ":horizontal_loop_fusion", ":instruction_fusion", ":ir_emission_utils", @@ -1769,6 +1772,7 @@ cc_library( srcs = ["horizontal_loop_fusion.cc"], hdrs = ["horizontal_loop_fusion.h"], deps = [ + ":gpu_fusible", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_creation_utils", @@ -1805,6 +1809,45 @@ tf_cc_test( ], ) +cc_library( + name = "horizontal_input_fusion", + srcs = ["horizontal_input_fusion.cc"], + hdrs = ["horizontal_input_fusion.h"], + deps = [ + ":gpu_fusible", + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "horizontal_input_fusion_test", + srcs = ["horizontal_input_fusion_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":horizontal_input_fusion", + ":multi_output_fusion", + "//tensorflow/compiler/jit:xla_gpu_jit", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "reduction_degenerate_dim_remover", srcs = ["reduction_degenerate_dim_remover.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index e9435f4fa92..25fbc0e05cb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -59,6 +59,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h" +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -306,6 +307,10 @@ Status GpuCompiler::OptimizeHloModule( HloPassPipeline horizontal_fusion("horizontal_fusion"); horizontal_fusion.AddPass(); + // The code generated for fusions created by GpuHorizontalInputFusion has + // been observed to fail with CUDA_ERROR_ILLEGAL_ADDRESS errors. + // TODO(b/171227713): Re-enable once the emitters are fixed. + // horizontal_fusion.AddPass(); horizontal_fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); horizontal_fusion.AddPass(); @@ -477,6 +482,47 @@ StatusOr> GpuCompiler::RunHloPasses( return std::move(module); } +static absl::optional DummyCanShareBufferFunction(const HloInstruction*, + const HloInstruction*, + const ShapeIndex&) { + return absl::nullopt; +} + +StatusOr< + std::tuple, std::unique_ptr>> +GpuCompiler::RunHloPassesAndBufferAssignement( + std::unique_ptr hlo_module, se::StreamExecutor* executor, + se::DeviceMemoryAllocator* device_allocator, bool optimize) { + if (optimize) { + TF_ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module), + executor, device_allocator)); + } + + std::unique_ptr stream_assignment = + AssignStreams(*hlo_module); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*hlo_module, *stream_assignment, pointer_size_)); + + auto buffer_size_bytes_function = + [this](const BufferValue& buffer_value) -> int64 { + return GpuCompiler::GetSizeOfShape(buffer_value.shape(), pointer_size_); + }; + + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + BufferAssigner::Run( + hlo_module.get(), hlo_schedule->ConsumeHloOrdering(), + buffer_size_bytes_function, + /*color_alignment=*/ + [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, + /*allocate_buffers_for_constants=*/true, + /*colorer=*/BufferAssigner::DefaultColorer(), + /*must_not_live_out=*/{}, DummyCanShareBufferFunction)); + + return std::make_tuple(std::move(hlo_module), std::move(assignment)); +} + // The order of `thunk_sequence` corresponds to // `hlo_schedule->ThunkLaunchOrder()`. static Status CompileModuleToLlvmIrImpl( @@ -722,12 +768,6 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); } -static absl::optional DummyCanShareBufferFunction(const HloInstruction*, - const HloInstruction*, - const ShapeIndex&) { - return absl::nullopt; -} - StatusOr> CompileModuleToLlvmIr( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index 7b6e4c78832..824d7404ebe 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -55,6 +55,13 @@ class GpuCompiler : public LLVMCompiler { std::unique_ptr module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) override; + StatusOr< + std::tuple, std::unique_ptr>> + RunHloPassesAndBufferAssignement(std::unique_ptr hlo_module, + se::StreamExecutor* executor, + se::DeviceMemoryAllocator* device_allocator, + bool optimize) override; + Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index ce319b4c59d..b69b32c17c5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -143,29 +143,27 @@ bool IsInputFusibleReduction(const HloInstruction& instr) { IsReductionFromOrToContiguousDimensions(instr); } +const HloInstruction* GetRealHeroForMultiOutputFusion( + const HloInstruction& instr) { + if (instr.opcode() != HloOpcode::kFusion) { + return &instr; + } + auto fused_expression_root = instr.fused_expression_root(); + if (!instr.IsMultiOutputFusion()) { + return fused_expression_root; + } + // If possible, we want to pick a reduction-from-or-to-contiguous-dims + // operand of the fusion root, because it has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (IsReductionFromOrToContiguousDimensions(*inst)) { + return inst; + } + } + return fused_expression_root->operands()[0]; +} + bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, const HloInstruction& instr2) { - // Returns the instructions that determines the emitter used for lowering, - // sometimes referred to as "the real hero". - auto get_real_hero = - [&](const HloInstruction* instr) -> const HloInstruction* { - if (instr->opcode() != HloOpcode::kFusion) { - return instr; - } - auto fused_expression_root = instr->fused_expression_root(); - if (!instr->IsMultiOutputFusion()) { - return fused_expression_root; - } - // If possible, we want to pick a reduction-to-vector operand of the - // fusion root, because it has the most constraints. - for (const auto* inst : fused_expression_root->operands()) { - if (IsReductionFromOrToContiguousDimensions(*inst)) { - return inst; - } - } - return fused_expression_root->operands()[0]; - }; - // Multi-output fusion kernels share a common parallel loop. The loop // dimensions are determined by instruction shapes. auto get_loop_shape = [&](const HloInstruction* element_instr) { @@ -181,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, // root ops should have equal output shapes. An exception are // reduction-to-vector ops. Here the input shapes of the reduction (first // operand shape) and the reduction dimensions need to match. - auto* instr_1 = get_real_hero(&instr1); - auto* instr_2 = get_real_hero(&instr2); + auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1); + auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2); if (IsReductionFromOrToContiguousDimensions(*instr_1) && IsReductionFromOrToContiguousDimensions(*instr_2) && !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) { @@ -524,5 +522,24 @@ HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/, : HloInstruction::FusionKind::kLoop; } +bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, + const HloInstruction& consumer) { + return absl::c_all_of(instr.users(), [&](const HloInstruction* user) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + // Skip GTE. + return IsConsumerTheOnlyNonRootUser(*user, consumer); + } + if (user == &consumer) { + // `user` is `consumer`. + return true; + } + if (user == user->parent()->root_instruction()) { + // Consumed by ROOT. + return true; + } + return false; + }); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index e7cac6e55c8..9fa098a3394 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -71,6 +71,11 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, bool CreatesNestedLoop(const HloInstruction& producer, const HloInstruction& consumer); +// Returns the instruction that determines the emitter used for lowering, +// sometimes referred to as "the real hero". +const HloInstruction* GetRealHeroForMultiOutputFusion( + const HloInstruction& instr); + // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. // This function works for both, sibling and producer-consumer multi-output @@ -100,6 +105,10 @@ bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr); HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer, const HloInstruction& consumer); +// Returns whether `consumer` is the only non-root user of `instr`. +bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, + const HloInstruction& consumer); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc new file mode 100644 index 00000000000..9287f9a92b7 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/core/platform/errors.h" + +namespace xla { +namespace gpu { + +namespace { + +// Gets the representative input shape of the multi-output fusion. +Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) { + // Get the HLO that determines the emitter used for lowering. + const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr); + if (real_hero->operands().empty()) { + // Simply return an empty shape if the representative node has no input + // operands. + return Shape(); + } else { + return real_hero->operand(0)->shape(); + } +} + +class HorizontalInputFusionImpl { + public: + explicit HorizontalInputFusionImpl(HloComputation* computation) + : computation_(computation) {} + + ~HorizontalInputFusionImpl() {} + + StatusOr Run(); + + private: + HloComputation* computation_; +}; // HorizontalInputFusionImpl + +// Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to +// right. +bool CompareShapeDimsFromLeftToRight(const Shape& shape_a, + const Shape& shape_b) { + if (shape_a.rank() != shape_b.rank()) { + return shape_a.rank() < shape_b.rank(); + } + auto dims_a = shape_a.dimensions(); + auto dims_b = shape_b.dimensions(); + for (size_t i = 0; i < dims_a.size(); ++i) { + if (dims_a[i] != dims_b[i]) { + return dims_a[i] < dims_b[i]; + } + } + return true; +} + +std::vector FindAndSortFusionCandidates( + HloInstruction* consumer) { + absl::flat_hash_set fusion_instr_set; + for (auto opnd : consumer->operands()) { + HloInstruction* predecessor = opnd->LatestNonGteAncestor(); + // Find out the input fusion instructions whose only consumer is `consumer`. + // This guarantees that fusing these candidates will never create cycles, as + // there is no back edge. + if (IsReduceInputFusion(*predecessor) && + IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) { + fusion_instr_set.insert(predecessor); + } + } + + std::vector fusion_instrs; + fusion_instrs.insert(fusion_instrs.end(), fusion_instr_set.begin(), + fusion_instr_set.end()); + + std::sort(fusion_instrs.begin(), fusion_instrs.end(), + [&](const HloInstruction* a, const HloInstruction* b) { + Shape shape_a = GetInputShapeForMultiOutputFusion(*a); + Shape shape_b = GetInputShapeForMultiOutputFusion(*b); + if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) { + // Sort shapes according to dimensions, so that the same input + // shapes will be placed adjacent each other. + return CompareShapeDimsFromLeftToRight(shape_a, shape_b); + } + // Sort `fusion_instrs` according to instruction counts, because + // we'd like to fuse together computations of similar sizes. + return a->fused_instruction_count() < + b->fused_instruction_count(); + }); + + return fusion_instrs; +} + +StatusOr HorizontalInputFusionImpl::Run() { + bool changed = false; + XLA_VLOG_LINES(3, computation_->ToString()); + + // Using def-to-use order is sound since we do not modify users. + std::vector def_to_use_order = + computation_->MakeInstructionPostOrder(); + for (auto consumer : def_to_use_order) { + auto candidates = FindAndSortFusionCandidates(consumer); + if (candidates.empty()) { + continue; + } + + size_t fusion_anchor_id = 0; + for (size_t j = 1; j < candidates.size(); ++j) { + HloInstruction* fusion_anchor = candidates[fusion_anchor_id]; + HloInstruction* fused = candidates[j]; + if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) && + !FusionWouldBeTooLarge(*fusion_anchor, *fused)) { + VLOG(3) << "Fuse " << fused->ToString() << " into " + << fusion_anchor->ToString(); + fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused); + changed = true; + } else { + // Update the `fusion_anchor_id` since `fused` is either not + // compatible or not beneficial to be fused with current fusion anchor. + VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused."; + fusion_anchor_id = j; + } + } + } + + return changed; +} + +} // namespace + +StatusOr GpuHorizontalInputFusion::RunOnComputation( + HloComputation* computation) { + HorizontalInputFusionImpl horizontal_fusion_impl(computation); + return horizontal_fusion_impl.Run(); +} + +StatusOr GpuHorizontalInputFusion::Run(HloModule* module) { + bool changed = false; + VLOG(2) << "Run horizontal input fusion."; + for (auto* comp : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(changed, RunOnComputation(comp)); + } + + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h new file mode 100644 index 00000000000..85313d03412 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { +namespace gpu { + +// This optimization pass horizontally fuses kInput fusions to both reduce the +// kernel launch overhead and increase parallelism degree. See +// GpuHorizontalFusion for general description and motivation about horizontal +// fusion. GpuHorizontalFusion deals with kLoop fusions while this pass deals +// with kInput fusions. +// +// Following GpuHorizontalFusion, a simple yet effective heuristic is used +// to search the fusion candidates while avoiding creating cycles. That is, +// we simply search for fusion candidates by looking for instructions whose +// outputs are all consumed by the same instruction. This catches the typical +// target cases; often, the candidate instructions are just consumed by the +// ROOT tuple of the entry computation. +class GpuHorizontalInputFusion : public HloModulePass { + public: + GpuHorizontalInputFusion() {} + + absl::string_view name() const override { + return "gpu_horizontal_input_fusion"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr RunOnComputation(HloComputation*); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc new file mode 100644 index 00000000000..96e46fe723c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc @@ -0,0 +1,217 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class HorizontalInputFusionTest : public GpuCodegenTest {}; + +TEST_F(HorizontalInputFusionTest, BasicTest) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule BasicTest + + %add_f16 { + %x = f16[] parameter(0) + %y = f16[] parameter(1) + ROOT %add = f16[] add(%x, %y) + } + + fused_computation.1 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + } + + fused_computation.2 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + } + + ENTRY entry_computation { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1 + fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2 + ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2) + } +)") + .ValueOrDie(); + + EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie()); + + const HloInstruction* entry_root = + module->entry_computation()->root_instruction(); + EXPECT_THAT(entry_root, op::Tuple((op::GetTupleElement(op::Fusion())), + (op::GetTupleElement(op::Fusion())))); + + const HloInstruction* fusion = entry_root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +// TODO(b/171227713): Re-enable once fixed. +TEST_F(HorizontalInputFusionTest, DISABLED_ManyInputFusions) { + auto module = CreateNewVerifiedModule(); + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + module->AddEmbeddedComputation(embedded_builder.Build()); + } + + HloComputation::Builder builder(TestName()); + std::vector var_outs; + auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024}); + auto output_shape = ShapeUtil::MakeShape(F32, {1024}); + for (int64 i = 0; i < 130; ++i) { + // %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) -> + // f32[1024] { + // %param_0 = f32[1024,1024]{1,0} parameter(0) + // %param_1 = f32[] parameter(1) + // %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1), + // dimensions={} + // %multiply = f32[1024,1024]{1,0} + // multiply(f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0} + // %broadcast) + // %constant0 = f32[] constant(0) + // ROOT %reduce = f32[1024]{0} + // reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0), + // dimensions={1}, to_apply=%add + // } + HloInstruction* param_var_in = builder.AddInstruction( + HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in")); + HloInstruction* param_alpha = + builder.AddInstruction(HloInstruction::CreateParameter( + i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha")); + auto alpha_broadcasted = builder.AddInstruction( + HloInstruction::CreateBroadcast(input_shape, param_alpha, {})); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted)); + HloInstruction* const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + output_shape, mul, const0, {1}, reduce_computation)); + var_outs.push_back(reduce); + } + builder.AddInstruction(HloInstruction::CreateTuple(var_outs)); + module->AddEntryComputation(builder.Build()); + + // Verify that horizontal fusion is kicked in. Check that there are multiple + // `reduce` instructions fused into the same fusion. 6 is just a randomly + // picked number as we don't exactly know how large the fusion will be + // created due to the `FusionWouldBeTooLarge` constraint. + CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)", + /*match_optimized_ir=*/false); + + // Testing with the entire gpu optimization pipeline. + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) { + // This tests the below pattern. One known issue is that gtes (to fusions) can + // be removed after their producer fusions are merged. In the below case, gte2 + // and gte6 will be gone if Fusion2 is fused into Fusion1. + // + // Fusion1 Fusion2 + // | | | | + // | gte1 gte2 | + // | | | | + // | Fusion3 | + // | | | | + // gte3 gte4 gte5 gte6 + // \ | | / + // =====ROOT===== + // + auto module = ParseAndReturnVerifiedModule(R"( + HloModule MultiOutputFusionTest + + %add_f16 { + %x = f16[] parameter(0) + %y = f16[] parameter(1) + ROOT %add = f16[] add(%x, %y) + } + + fused_computation.1 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + add.0 = f16[1024] add(arg.1, arg.1) + ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0) + } + + fused_computation.2 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + add.0 = f16[1024] add(arg.1, arg.1) + ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0) + } + + fused_computation.3 { + arg.0 = f16[1024]{0} parameter(0) + arg.1 = f16[1024]{0} parameter(1) + add.0 = f16[1024] add(arg.0, arg.1) + mul.0 = f16[1024] multiply(arg.0, arg.1) + ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0) + } + + ENTRY entry_computation { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1 + fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2 + gte.3 = f16[] get-tuple-element(fusion.1), index=0 + gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1 + gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1 + gte.6 = f16[] get-tuple-element(fusion.2), index=0 + fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2), + kind=kLoop, calls=fused_computation.3 + gte.4 = f16[1024] get-tuple-element(fusion.3), index=0 + gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1 + ROOT tuple.1 = (f16[], f16[1024]{0}, f16[], f16[1024]{0}) + tuple(gte.3, gte.4, gte.5, gte.6) + } +)") + .ValueOrDie(); + + EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc index 577c7eed6c4..9d1e0533a91 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/util/env_var.h" @@ -137,25 +138,6 @@ bool IsFusionSupported(const HloInstruction& instr) { return true; } -bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, - const HloInstruction& consumer) { - return absl::c_all_of(instr.users(), [&](const HloInstruction* user) { - if (user->opcode() == HloOpcode::kGetTupleElement) { - // Skip GTE. - return IsConsumerTheOnlyNonRootUser(*user, consumer); - } else if (user == &consumer) { - // `user` is `consumer`. - return true; - } else if (user == user->parent()->root_instruction()) { - // Consumed by ROOT is always fine, since it is impossible to create - // cycles through ROOT. - return true; - } else { - return false; - } - }); -} - // Returns whether `instr` is a profitable candidate to be horizontally fused. // Since the primary benefit of horizontal fusion comes from reducing the // kernel launch overhead, we want to exclude the instructions with diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index d33474f83c2..0a6009d7462 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -38,6 +38,7 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" @@ -87,6 +88,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -203,7 +205,7 @@ StatusOr GetAllocationSliceForMlir( } StatusOr> GetMlirBufferSlices( - mlir::Operation* op, mlir::OperandRange operands, + mlir::Operation* op, mlir::ValueRange operands, absl::Span allocations) { const auto buffer_is_written = [op](mlir::Value operand) { llvm::SmallVector effects; @@ -227,6 +229,55 @@ StatusOr> GetMlirBufferSlices( return slices; } +bool BinarySearchDenseElementsAttr(::mlir::DenseIntElementsAttr elements, + int64 v) { + ::mlir::APInt value(sizeof(int64) * 8, v, /*isSigned=*/true); + return std::binary_search( + elements.begin(), elements.end(), value, + [](const ::mlir::APInt& x, const ::mlir::APInt& y) { return x.slt(y); }); +} + +// Returns true if the fusion contains any instruction that is likely +// translated to complex LLVM IR, such as loops, and prevent vectorization. +bool MayPreventVectorization(const HloInstruction& hlo) { + if (hlo.opcode() == HloOpcode::kFusion) { + return absl::c_any_of(hlo.fused_instructions_computation()->instructions(), + [](const HloInstruction* instr) { + switch (instr->opcode()) { + case HloOpcode::kReduceWindow: + case HloOpcode::kSort: + case HloOpcode::kDot: + case HloOpcode::kSin: + case HloOpcode::kCos: + case HloOpcode::kPower: + case HloOpcode::kAtan2: + return true; + default: + return false; + } + }); + } else if (hlo.IsElementwise()) { + // Unfused elementwise operations are usually memory bound, unroll them. + switch (hlo.opcode()) { + // The following elementwise operation implementations contain branches. + // LLVM vectorizer doesn't work in that case. + // The unrolled code is faster when it isn't vectorized. + case HloOpcode::kSin: + case HloOpcode::kCos: + case HloOpcode::kPower: + case HloOpcode::kAtan2: + return true; + default: + return false; + } + } else if (hlo.opcode() == HloOpcode::kReduce && hlo.shape().IsArray()) { + // TODO: check if the to_apply() attribute contains instruction + // that break LLVM vectorization. + return false; + } + return true; +} + } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, @@ -405,6 +456,62 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, return b->getInt32Ty(); } +// The same as GetIndexTypeForKernel, but works with MLIR ops. +llvm::Type* GetIndexTypeForKernelFromMlir(mlir::Operation* op, + int64 launch_size, + llvm::IRBuilder<>* b) { + auto shape_in_range = [&](const Shape& s) { + bool in_range = true; + ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape, + const ShapeIndex& /*index*/) { + if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { + in_range = false; + } + }); + + return in_range; + }; + + llvm::Type* i64_ty = b->getInt64Ty(); + // Check launch dimension + if (!IsInt32(launch_size)) { + return i64_ty; + } + + // Check the size of result tensors + for (auto result : op->getResults()) { + if (!shape_in_range(TypeToShape(result.getType()))) { + return i64_ty; + } + } + + auto hlo_shape_in_range = [&](mlir::Value operand) -> bool { + return shape_in_range(TypeToShape(operand.getType())); + }; + + // Check the size of input tensors + if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) { + return i64_ty; + } + + // Check the size of the internal result tensors + if (auto fusion = mlir::dyn_cast(op)) { + auto result = fusion.region().walk([&](mlir::Operation* op) { + for (mlir::Value result : op->getResults()) { + if (!hlo_shape_in_range(result)) { + return mlir::WalkResult::interrupt(); + } + } + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return i64_ty; + } + } + + return b->getInt32Ty(); +} + // Gets the input shape of the ROOT slices, which will be used as the kernel // launch dims. The slice input fusion requires the input shapes of the ROOT // slices to be the same although the (slice) output shapes can be different. @@ -703,6 +810,202 @@ Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) { return ThunkEmitter(this).HandleTriangularSolve(hlo); } +// Convert the following form of fusion region: +// fusion() { +// %0 = tensor_load %external_memref0 +// %1 = tensor_load %external_memref1 +// ... +// tensor_store %ret, %external_memref2 +// } +// to +// fusion(%external_memref0, %external_memref1) (^bb(%0, %1) { +// ... +// mhlo.return %ret +// }) +// +// So that it's suitable for MHLO -> XLA HLO conversion. +// This function won't be needed once ElementalIrEmitter migrates to take MHLO +// instead. +static Status ProcessFusionForConversion(mlir::Region* region, + std::vector* operands, + std::vector* outputs, + std::vector* operand_shapes) { + std::vector loads; + std::vector stores; + + region->walk([&](mlir::TensorLoadOp load) { + if (load.memref().getParentRegion() != region) { + loads.push_back(load); + } + }); + + region->walk([&](mlir::TensorStoreOp store) { + if (store.memref().getParentRegion() != region) { + stores.push_back(store); + } + }); + + for (auto load : loads) { + auto arg = region->addArgument(load.getType()); + load.replaceAllUsesWith(arg); + operands->push_back(load.memref()); + Shape shape = TypeToShape(load.getType()); + auto attr = mlir::GetLayoutFromMlirHlo(load); + if (attr) { + std::vector minor_to_major; + absl::c_transform( + attr, std::back_inserter(minor_to_major), + std::function(&llvm::APInt::getZExtValue)); + *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); + } else { + *shape.mutable_layout() = + LayoutUtil::MakeDescendingLayout(load.getType().getShape().size()); + } + operand_shapes->push_back(shape); + load.erase(); + } + + std::vector returned_values; + for (auto store : stores) { + returned_values.push_back(store.tensor()); + outputs->push_back(store.memref()); + store.erase(); + } + + region->back().back().erase(); + auto b = mlir::OpBuilder::atBlockEnd(®ion->back()); + auto loc = returned_values[0].getLoc(); + b.create(loc, returned_values); + return Status::OK(); +} + +// Similar to the general GetMlirBufferSlices, but it's specific to fusion, +// since fusion doesn't have any ODS operands and memory side-effect +// annotations. +static StatusOr> CreateFusionSlices( + absl::Span fusion_operands, + absl::Span fusion_outputs, + absl::Span operand_shapes, const Shape& output_shape, + const BufferAssignment& buffer_assignment) { + absl::Span allocations( + buffer_assignment.Allocations()); + + std::vector slices; + for (int i = 0; i < fusion_operands.size(); i++) { + mlir::Value operand = fusion_operands[i]; + MlirBufferSlice slice; + TF_ASSIGN_OR_RETURN(slice.buffer_slice, + GetAllocationSliceForMlir(operand, allocations)); + slice.shape = operand_shapes.at(i); + slices.push_back(slice); + } + for (int i = 0; i < fusion_outputs.size(); i++) { + mlir::Value output = fusion_outputs[i]; + MlirBufferSlice slice; + TF_ASSIGN_OR_RETURN(slice.buffer_slice, + GetAllocationSliceForMlir(output, allocations)); + slice.written = true; + if (output_shape.IsTuple()) { + slice.shape = output_shape.tuple_shapes(i); + } else { + slice.shape = output_shape; + } + slices.push_back(slice); + } + + return slices; +} + +// TODO(timshen): update the comment once the HandleFusion code path deleted. +// +// This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the +// subclass. The logic is de-virtualized and less scattered. +Status IrEmitterUnnested::EmitLoopFusionFromMlir(MlirEmitterInput input, + const Shape& output_shape, + int unroll_factor) { + auto fusion = mlir::cast(input.op); + std::string name = mlir::GetNameFromLoc(fusion.getLoc()); + + std::vector fusion_operands; + std::vector fusion_outputs; + std::vector operand_shapes; + TF_RETURN_IF_ERROR(ProcessFusionForConversion( + &fusion.region(), &fusion_operands, &fusion_outputs, &operand_shapes)); + TF_ASSIGN_OR_RETURN( + std::vector slices, + CreateFusionSlices(fusion_operands, fusion_outputs, operand_shapes, + output_shape, + ir_emitter_context_->buffer_assignment())); + slices.push_back(input.extra_slice); + + std::vector ir_arrays; + Thunk* kernel_thunk; + { + std::unique_ptr kernel_thunk_ptr = + BuildKernelThunkForMlir(name, input.thunk_info, slices, &ir_arrays); + kernel_thunk = kernel_thunk_ptr.get(); + thunk_sequence_.emplace_back(std::move(kernel_thunk_ptr)); + } + + TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation, + GetOrCreateSubComputationFromRegion(&fusion.region())); + + CHECK_EQ(fusion_operands.size(), fused_computation->num_parameters()); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + *fused_computation->parameter_instruction(i) + ->mutable_shape() + ->mutable_layout() = slices[i].shape.layout(); + } + + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + + FusedIrEmitter fused_emitter( + [&] { + auto operand_ir_arrays = + absl::MakeSpan(ir_arrays).subspan(0, fusion_operands.size()); + return std::vector(operand_ir_arrays.begin(), + operand_ir_arrays.end()); + }, + &elemental_emitter); + TF_RETURN_IF_ERROR( + fused_computation->root_instruction()->Accept(&fused_emitter)); + + auto element_generator = fused_emitter.GetRootGenerator(); + Shape element_shape = TypeToShape(fusion_outputs[0].getType()); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk, + ir_emitter_context_->llvm_module()); + auto output_arrays = absl::MakeSpan(ir_arrays).subspan(fusion_operands.size(), + fusion_outputs.size()); + llvm::Type* index_type = GetIndexTypeForKernelFromMlir( + fusion, launch_dimensions.launch_bound(), &b_); + + if (fusion_outputs.size() > 1) { + // Emit the tuple pointers in one thread. We could do this at any point in + // the kernel, but we do it at the beginning in the hopes of reducing + // register pressure, since we touch threadIdx.x and blockIdx.x at the + // beginning of the kernel *anyway*. + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(ir_arrays.back(), output_arrays, &b_); + }); + // For multioutput fusion, we need to emit each operand and the root. + TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, + launch_dimensions, &b_, + unroll_factor) + .EmitLoop(name, index_type)); + } else { + TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays[0], + launch_dimensions, &b_, + unroll_factor) + .EmitLoop(name, index_type)); + } + + b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); + return Status::OK(); +} + Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); if (fusion->IsInputFusion()) { @@ -744,12 +1047,30 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { GetGeneratorForOperandIrArrays(fusion), &scatter_elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter)); - TF_RETURN_IF_ERROR(EmitScatter( - thunks.back().get(), root, - /*scatter_indices_gen=*/ - scatter_fused_emitter.GetGenerator(root->operand(1)), - /*updates_gen=*/ - scatter_fused_emitter.GetGenerator(root->operand(2)))); + CHECK_EQ(root->parent()->FusionInstruction(), fusion); + + TF_ASSIGN_OR_RETURN( + const auto dim_numbers, + lhlo_scratch_emitter_.GetScatterDimensionNumbers(root)); + + ScatterDescriptor desc; + desc.name = IrName(root); + desc.operand_shape = root->operand(0)->shape(); + desc.scatter_indices_shape = root->operand(1)->shape(); + desc.updates_shape = root->operand(2)->shape(); + desc.dim_numbers = dim_numbers; + desc.unique_indices = root->unique_indices(); + desc.update_computation = root->called_computations()[0]; + desc.output = GetIrArray(*fusion, *fusion); + desc.scatter_indices_gen = + scatter_fused_emitter.GetGenerator(root->operand(1)); + desc.updates_gen = + scatter_fused_emitter.GetGenerator(root->operand(2)); + desc.get_index_type = [&](int64 launch_size) { + return GetIndexTypeForKernel(root, launch_size, &b_); + }; + + TF_RETURN_IF_ERROR(EmitScatter(desc, thunks.back().get())); } AddThunkToThunkSequence(absl::make_unique( GetThunkInfo(fusion), std::move(thunks))); @@ -823,7 +1144,22 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - return IrEmitter::HandleFusion(fusion); + int unroll_factor = 1; + if (!MayPreventVectorization(*fusion)) { + unroll_factor = ComputeMaxUnrollFactor(fusion); + } + + MlirEmitterInput input; + TF_ASSIGN_OR_RETURN(input.op, lhlo_scratch_emitter_.EmitFusionOp(fusion)); + const auto& buffer_assignment = ir_emitter_context_->buffer_assignment(); + auto& slice = input.extra_slice; + TF_ASSIGN_OR_RETURN(slice.buffer_slice, + buffer_assignment.GetUniqueSlice(fusion, {})); + slice.written = true; + slice.shape = fusion->shape(); + input.thunk_info = GetThunkInfo(fusion); + + return EmitLoopFusionFromMlir(input, fusion->shape(), unroll_factor); } Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { @@ -1177,61 +1513,118 @@ Status IrEmitterUnnested::HandleRngGetAndUpdateState( } Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { - const HloInstruction* operand = scatter->operand(0); - const HloInstruction* scatter_indices = scatter->operand(1); - const HloInstruction* updates = scatter->operand(2); + MlirEmitterInput result; + + TF_ASSIGN_OR_RETURN(auto scatter_op, + lhlo_scratch_emitter_.EmitScatterOp(scatter)); + result.op = scatter_op; + result.thunk_info = GetThunkInfo(scatter); + return EmitScatterFromMlir(result); +} + +Status IrEmitterUnnested::EmitScatterFromMlir(MlirEmitterInput mlir_input) { std::vector> thunks; + absl::Span allocations( + ir_emitter_context_->buffer_assignment().Allocations()); + + ::mlir::lmhlo::ScatterOp scatter_op = + ::mlir::cast<::mlir::lmhlo::ScatterOp>(mlir_input.op); + + TF_ASSIGN_OR_RETURN( + auto operand_buffer, + GetAllocationSliceForMlir(scatter_op.operand(), allocations)); + TF_ASSIGN_OR_RETURN( + auto output_buffer, + GetAllocationSliceForMlir(scatter_op.output(), allocations)); + // Copy the operand into the output if it's not the same buffer already. - auto operand_buffer = GetAllocationSlice(*operand); - auto destination_buffer = GetAllocationSlice(*scatter); - if (operand_buffer != destination_buffer) { + if (operand_buffer != output_buffer) { thunks.push_back(absl::make_unique( Thunk::ThunkInfo(), /*source_address=*/operand_buffer, - /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()))); + /*destination_buffer=*/output_buffer, + /*mem_size=*/ + ShapeUtil::ByteSizeOf(TypeToShape(scatter_op.output().getType())))); } - thunks.push_back( - BuildKernelThunk(scatter, - /*implements_whole_instruction=*/thunks.empty())); + // Create MLIR buffer slice info for all operands except the first one + // (`operand`). The code generated for scatter below assumes that the input + // operand is already copied into the output, so does not use it in codegen. + TF_ASSIGN_OR_RETURN( + std::vector operand_slices, + GetMlirBufferSlices(scatter_op, scatter_op.getOperands().drop_front(), + allocations)); + + std::string name = mlir::GetNameFromLoc(scatter_op.getLoc()); + std::vector ir_arrays; + thunks.push_back(BuildKernelThunkForMlir(name, mlir_input.thunk_info, + operand_slices, &ir_arrays)); + CHECK_EQ(ir_arrays.size(), 3); + const IrArray& scatter_indices = ir_arrays[0]; + const IrArray& updates = ir_arrays[1]; + const IrArray& output = ir_arrays[2]; + + auto get_index_type = [&](int64 launch_size) { + return GetIndexTypeForKernelFromMlir(scatter_op, launch_size, &b_); + }; TF_RETURN_IF_ERROR(EmitScatter( - thunks.back().get(), scatter, + thunks.back().get(), scatter_op, output, /*scatter_indices_gen=*/ - [=](const IrArray::Index& index) { - return GetIrArray(*scatter_indices, *scatter) - .EmitReadArrayElement(index, &b_, "scatter_index"); + [&](const IrArray::Index& index) { + return scatter_indices.EmitReadArrayElement(index, &b_, + "scatter_index"); }, /*updates_gen=*/ - [=](const IrArray::Index& index) { - return GetIrArray(*updates, *scatter) - .EmitReadArrayElement(index, &b_, "update"); - })); + [&](const IrArray::Index& index) { + return updates.EmitReadArrayElement(index, &b_, "update"); + }, + /* get_index_type=*/ + get_index_type)); // Elide the sequential thunk if there's no copy. if (thunks.size() == 1) { AddThunkToThunkSequence(std::move(thunks[0])); } else { AddThunkToThunkSequence(absl::make_unique( - GetThunkInfo(scatter), std::move(thunks))); + mlir_input.thunk_info, std::move(thunks))); } return Status::OK(); } Status IrEmitterUnnested::EmitScatter( - Thunk* thunk, HloInstruction* scatter, + Thunk* thunk, mlir::lmhlo::ScatterOp scatter, + const llvm_ir::IrArray& output, const llvm_ir::ElementGenerator& scatter_indices_gen, - const llvm_ir::ElementGenerator& updates_gen) { - const HloInstruction* operand = scatter->operand(0); - const HloInstruction* scatter_indices = scatter->operand(1); - const HloInstruction* updates = scatter->operand(2); - const ScatterDimensionNumbers& dim_numbers = - scatter->scatter_dimension_numbers(); - CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape())); + const llvm_ir::ElementGenerator& updates_gen, + std::function get_index_type) { + const Shape operand_shape = TypeToShape(scatter.operand().getType()); + CHECK( + ShapeUtil::Equal(TypeToShape(scatter.output().getType()), operand_shape)); + TF_ASSIGN_OR_RETURN( + const HloComputation* update_computation, + GetOrCreateSubComputationFromRegion(&scatter.update_computation())); + + ScatterDescriptor desc; + desc.name = mlir::GetNameFromLoc(scatter.getLoc()); + desc.operand_shape = operand_shape; + desc.scatter_indices_shape = TypeToShape(scatter.scatter_indices().getType()); + desc.updates_shape = TypeToShape(scatter.updates().getType()); + desc.dim_numbers = scatter.scatter_dimension_numbers(); + desc.unique_indices = scatter.unique_indices(); + desc.update_computation = update_computation; + desc.output = output; + desc.scatter_indices_gen = scatter_indices_gen; + desc.updates_gen = updates_gen; + desc.get_index_type = get_index_type; + return EmitScatter(desc, thunk); +} + +Status IrEmitterUnnested::EmitScatter(const ScatterDescriptor& desc, + Thunk* thunk) { auto loop_body_emitter = [&](const IrArray::Index& index) -> Status { std::vector raw_window_multidim; std::vector input_scatter_multidim; @@ -1241,22 +1634,25 @@ Status IrEmitterUnnested::EmitScatter( for (int64 i = 0, e = index.size(); i != e; ++i) { // For window indices also remember the window size, this comes in handy // later. - if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { + if (BinarySearchDenseElementsAttr(desc.dim_numbers.update_window_dims(), + i)) { raw_window_multidim.push_back(index[i]); - raw_window_bounds.push_back(updates->shape().dimensions(i)); + raw_window_bounds.push_back(desc.updates_shape.dimensions(i)); } else { input_scatter_multidim.push_back(index[i]); } } DCHECK_EQ(raw_window_multidim.size(), - dim_numbers.update_window_dims_size()); + desc.dim_numbers.update_window_dims().size()); // Apply inserted_window_dims to the window dimensions. int64 raw_window_multidim_idx = 0; std::vector input_window_multidim; std::vector input_window_bounds; - for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) { - if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { + + for (int64 i = 0, e = desc.operand_shape.rank(); i != e; ++i) { + if (BinarySearchDenseElementsAttr(desc.dim_numbers.inserted_window_dims(), + i)) { input_window_bounds.push_back(1); // Trivial dimension. input_window_multidim.push_back(index.GetConstantWithIndexType(0)); } else { @@ -1267,14 +1663,15 @@ Status IrEmitterUnnested::EmitScatter( ++raw_window_multidim_idx; } } - DCHECK_EQ(input_window_multidim.size(), operand->shape().rank()); + DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank()); // Insert a 1 dimension at the end if index_vector_dim requests one. - Shape scatter_indices_shape = scatter_indices->shape(); - if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) { - scatter_indices_shape.add_dimensions(1); - scatter_indices_shape.mutable_layout()->add_minor_to_major( - dim_numbers.index_vector_dim()); + Shape scatter_indices_shape_fixed = desc.scatter_indices_shape; + if (desc.dim_numbers.index_vector_dim().getInt() == + desc.scatter_indices_shape.rank()) { + scatter_indices_shape_fixed.add_dimensions(1); + scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major( + desc.dim_numbers.index_vector_dim().getInt()); } // Now load the indices corresponding to the current window from @@ -1282,23 +1679,27 @@ Status IrEmitterUnnested::EmitScatter( std::vector raw_scatter_index_multidim = input_scatter_multidim; raw_scatter_index_multidim.insert( - raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(), + raw_scatter_index_multidim.begin() + + desc.dim_numbers.index_vector_dim().getInt(), nullptr); llvm::Value* is_in_bounds = b_.getTrue(); - for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size(); + for (int64 i = 0, + e = desc.dim_numbers.scatter_dims_to_operand_dims().size(); i != e; ++i) { // Our index is stored along index_vector_dim, insert that into the lookup // index into scatter_indices. - raw_scatter_index_multidim[dim_numbers.index_vector_dim()] = + raw_scatter_index_multidim[desc.dim_numbers.index_vector_dim().getInt()] = index.GetConstantWithIndexType(i); llvm_ir::IrArray::Index raw_scatter_index_index( - raw_scatter_index_multidim, scatter_indices_shape, index.GetType()); + raw_scatter_index_multidim, scatter_indices_shape_fixed, + index.GetType()); - int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i); + int64 operand_dim = + desc.dim_numbers.scatter_dims_to_operand_dims().getValue(i); TF_ASSIGN_OR_RETURN( llvm::Value* const loaded_scatter_index, - scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( - scatter_indices_shape, scatter_indices->shape(), &b_))); + desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( + scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_))); // And add the index to our window index. This yields the output index. llvm::Value* casted_scatter_index = IntCast(loaded_scatter_index, index.GetType(), @@ -1308,7 +1709,7 @@ Status IrEmitterUnnested::EmitScatter( input_window_multidim[operand_dim] = dim_offset; // Also do the bounds check now. - int64 max_index = operand->shape().dimensions(operand_dim) - + int64 max_index = desc.operand_shape.dimensions(operand_dim) - input_window_bounds[operand_dim] + 1; // is_in_bounds = index >= 0 && index < dim_size-window_size+1 // --> index u< dim_size-window_size+1 @@ -1322,25 +1723,23 @@ Status IrEmitterUnnested::EmitScatter( llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_); // All done, now just read from the calculated input from the window, and do // an atomic store to the calculated location in the output. - HloInstruction* output_hlo = - scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter; llvm_ir::IrArray::Index input_window_index( - input_window_multidim, output_hlo->shape(), index.GetType()); + input_window_multidim, desc.output.GetShape(), index.GetType()); llvm::Value* output_address = - GetIrArray(*output_hlo, *output_hlo) - .EmitArrayElementAddress(input_window_index, &b_); + desc.output.EmitArrayElementAddress(input_window_index, &b_); llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(updates->shape().element_type(), + llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(), module_), "input_address", &b_); - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index)); + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + desc.updates_gen(index)); Store(input_ir_value, input_address); - if (!scatter->unique_indices()) { + if (!desc.unique_indices) { return EmitAtomicOperationForNestedComputation( - *scatter->to_apply(), output_address, input_address); + *desc.update_computation, output_address, input_address); } else { - return EmitCallToNestedComputation(*scatter->to_apply(), + return EmitCallToNestedComputation(*desc.update_computation, {output_address, input_address}, output_address); } @@ -1350,31 +1749,52 @@ Status IrEmitterUnnested::EmitScatter( // also do one kernel per window instead if bounds checks turn out to be a // bottleneck. LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - updates->shape(), ir_emitter_context_->gpu_device_info()); + desc.updates_shape, ir_emitter_context_->gpu_device_info()); UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, updates->shape(), + return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape, launch_dimensions, &b_) - .EmitLoop(IrName(scatter), - GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(), - &b_)); + .EmitLoop(desc.name, + desc.get_index_type(launch_dimensions.launch_bound())); } Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } +// This transformation should be migrated off. See b/171334474. StatusOr IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region) { std::unique_ptr& module = scratch_nested_computations_[region]; if (module == nullptr) { xla::XlaComputation xla_computation; - TF_RETURN_IF_ERROR(ConvertRegionToComputation(region, &xla_computation)); + mlir::MlirToHloConversionOptions options; + options.propagate_layouts = true; + TF_RETURN_IF_ERROR( + ConvertRegionToComputation(region, &xla_computation, options)); + TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( module, HloModule::CreateFromProto(xla_computation.proto(), HloModuleConfig(program_shape))); + + // Post-process the generated computation: + // * Sanitize constant names, so that they can be used as LLVM global + // symbols. + // * Propagate layouts for tuple types. + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instr : computation->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kConstant) { + instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(*instr)); + } + if (instr->shape().IsTuple()) { + TF_ASSIGN_OR_RETURN(*instr->mutable_shape(), + ShapeInference::InferVariadicOpShape( + instr->opcode(), instr->operands())); + } + } + } } return module->entry_computation(); } @@ -2374,51 +2794,6 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( return Status::OK(); } -namespace { - -// Returns true if the fusion contains any instruction that is likely -// translated to complex LLVM IR, such as loops, and prevent vectorization. -bool MayPreventVectorization(const HloInstruction& hlo) { - if (hlo.opcode() == HloOpcode::kFusion) { - return absl::c_any_of(hlo.fused_instructions_computation()->instructions(), - [](const HloInstruction* instr) { - switch (instr->opcode()) { - case HloOpcode::kReduceWindow: - case HloOpcode::kSort: - case HloOpcode::kDot: - case HloOpcode::kSin: - case HloOpcode::kCos: - case HloOpcode::kPower: - case HloOpcode::kAtan2: - return true; - default: - return false; - } - }); - } else if (hlo.IsElementwise()) { - // Unfused elementwise operations are usually memory bound, unroll them. - switch (hlo.opcode()) { - // The following elementwise operation implementations contain branches. - // LLVM vectorizer doesn't work in that case. - // The unrolled code is faster when it isn't vectorized. - case HloOpcode::kSin: - case HloOpcode::kCos: - case HloOpcode::kPower: - case HloOpcode::kAtan2: - return true; - default: - return false; - } - } else if (hlo.opcode() == HloOpcode::kReduce && hlo.shape().IsArray()) { - // TODO: check if the to_apply() attribute contains instruction - // that break LLVM vectorization. - return false; - } - return true; -} - -} // namespace - Status IrEmitterUnnested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) { int unroll_factor = 1; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 5cc5e206167..40546dbb50d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" namespace xla { @@ -51,7 +52,7 @@ struct HloBufferSlice : public BufferSlice { struct MlirBufferSlice : public BufferSlice { // The buffer is modified by the kernel. - bool written; + bool written = false; Shape shape; }; @@ -59,6 +60,15 @@ struct MlirBufferSlice : public BufferSlice { struct MlirEmitterInput { mlir::Operation* op; Thunk::ThunkInfo thunk_info; + + // A field to allow plumbing extra information that BufferAssignment has, but + // LMHLO/MLIR representation does not have. Specifically, this is for passing + // the allocated buffer for tuple outputs (an array of pointers to tuple + // elements). + // + // TODO(timshen): We need a corresponding construct in LMHLO to represent + // this, aka an array of pointers to different memrefs. Once we have that, we + // can merge this information back to LMHLO graph and remove this field. MlirBufferSlice extra_slice; }; @@ -148,6 +158,8 @@ class IrEmitterUnnested : public IrEmitter, Status HandleDot(HloInstruction* dot) override; Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; + Status EmitLoopFusionFromMlir(MlirEmitterInput input, + const Shape& output_shape, int unroll_factor); Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; @@ -158,6 +170,7 @@ class IrEmitterUnnested : public IrEmitter, Status HandleRng(HloInstruction* random) override; Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override; Status HandleScatter(HloInstruction* scatter) override; + Status EmitScatterFromMlir(MlirEmitterInput mlir_input); Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; Status EmitSortFromMlir(MlirEmitterInput mlir_input); @@ -407,16 +420,38 @@ class IrEmitterUnnested : public IrEmitter, const llvm_ir::IrArray::Index& slice_input_index); // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in - // the process. `scatter` may be fused, scatter indices are taken from - // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is - // expected to have the operand values in it already. If unique_indices - // is false, we will use an atomic update. Using true for unique_indices - // behaves properly only when it is guaranteed that the indices to be - // updated do not overlap. The caller is responsible for ensuring this is - // the case. - Status EmitScatter(Thunk* thunk, HloInstruction* scatter, + // the process. Scatter indices are taken from `scatter_indices_gen`, updates + // from `updates_gen`. The output buffer is expected to have the operand + // values in it already. If unique_indices is false, we will use an atomic + // update. Using true for unique_indices behaves properly only when it is + // guaranteed that the indices to be updated do not overlap. The caller is + // responsible for ensuring this is the case. + Status EmitScatter(Thunk* thunk, mlir::lmhlo::ScatterOp scatter, + const llvm_ir::IrArray& output, const llvm_ir::ElementGenerator& scatter_indices_gen, - const llvm_ir::ElementGenerator& updates_gen); + const llvm_ir::ElementGenerator& updates_gen, + std::function get_index_type); + + // Structure describing a scatter operation for IR emission. + // TODO(jurahul): Migrate element generators to use MLIR. + // Migrate update_computation to be an MLIR Region. + struct ScatterDescriptor { + std::string name; + Shape operand_shape; + Shape scatter_indices_shape; + Shape updates_shape; + mlir::mhlo::ScatterDimensionNumbers dim_numbers; + bool unique_indices; + const HloComputation* update_computation; + llvm_ir::IrArray output; + llvm_ir::ElementGenerator scatter_indices_gen; + llvm_ir::ElementGenerator updates_gen; + std::function get_index_type; + }; + + // Emits code for an in-place scatter using the provided scatter operation + // description. + Status EmitScatter(const ScatterDescriptor& desc, Thunk* thunk); // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel // for the hlo instruction. diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 04af67a70b9..51583117706 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -149,7 +149,7 @@ std::unique_ptr GetTargetMachine( } llvm::TargetOptions target_options = - llvm::codegen::InitTargetOptionsFromCodeGenFlags(); + llvm::codegen::InitTargetOptionsFromCodeGenFlags(llvm::Triple()); // Set the verbose assembly options. target_options.MCOptions.AsmVerbose = false; @@ -225,7 +225,7 @@ void EmitBitcodeToFile(const llvm::Module& module, absl::string_view filename) { // for the NVPTX target. string EmitModuleToPTX(llvm::Module* module, llvm::TargetMachine* target_machine) { - std::string ptx; // need a std::string instead of a ::string. + std::string ptx; { llvm::raw_string_ostream stream(ptx); llvm::buffer_ostream pstream(stream); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 4ee78b0874c..fa73ac261f8 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/types.h" @@ -148,6 +149,16 @@ std::vector GetProducerConsumerMultiOutputFusionCandidates( << " would be too large of a fusion."; continue; } + // Make sure the emitter can codegen the fusion op efficiently. We currently + // can have exponential time/memory requirements for emitting certain fusion + // ops, in which case we don't want to fuse. + // TODO(b/119692968): Remove this once fixed in the emitter. + if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) { + VLOG(3) << "Fusion of " << producer->name() << " into " + << consumer->name() + << " would result in overly large code duplication."; + continue; + } fusion_candidates.push_back(consumer); } return fusion_candidates; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 283eb30d15f..6cb66290a9a 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -909,5 +909,165 @@ TEST_F(MultiOutputFusionTest, SharedMemoryBudget) { EXPECT_EQ(2, CountMultiOutputFusions(module.get())); } +TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule module + +and.reduce_sub_computation { + x = pred[] parameter(0) + y = pred[] parameter(1) + ROOT and = pred[] and(x, y) +} + +fused_computation.1 { + param_4.658 = f32[2,20,256]{2,0,1} parameter(4) + slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]} + constant.6847 = s32[] constant(0) + broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={} + param_9.415 = s32[3]{0} parameter(9) + compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE + constant.6846 = pred[] constant(true) + reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={} + param_5.528 = f32[2,512]{1,0} parameter(5) + slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]} + bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384) + constant.5418 = f32[] constant(0) + broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={} + select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227) + add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173) + param_0.299 = s32[] parameter(0) + constant.5157 = s32[] constant(11) + dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299) + slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]} + constant.6800 = s32[] constant(0) + broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={} + param_8.484 = s32[3]{0} parameter(8) + compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE + constant.6798 = pred[] constant(true) + reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={} + param_3.1169 = f32[2,512]{1,0} parameter(3) + slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]} + bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382) + select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227) + add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172) + constant.5154 = s32[] constant(10) + dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299) + slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]} + constant.6794 = s32[] constant(0) + broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={} + param_7.478 = s32[3]{0} parameter(7) + compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE + constant.6793 = pred[] constant(true) + reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={} + param_2.1685 = f32[2,512]{1,0} parameter(2) + slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]} + bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380) + select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227) + add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171) + constant.5153 = s32[] constant(9) + dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299) + slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]} + constant.6788 = s32[] constant(0) + broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={} + param_6.495 = s32[3]{0} parameter(6) + compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE + constant.6786 = pred[] constant(true) + reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={} + param_1.1408 = f32[2,512]{1,0} parameter(1) + slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]} + bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378) + select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227) + add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170) + constant.5152 = s32[] constant(8) + ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299) +} + +fused_computation.2 { + param_4.655 = f32[2,20,256]{2,0,1} parameter(4) + slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]} + param_6.483 = pred[] parameter(6) + broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={} + param_5.525 = f32[2,512]{1,0} parameter(5) + slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]} + bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368) + constant.5415 = f32[] constant(0) + broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={} + select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225) + add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161) + param_0.265 = s32[] parameter(0) + constant.5151 = s32[] constant(7) + dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265) + slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]} + constant.6782 = s32[] constant(0) + broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={} + param_9.391 = s32[3]{0} parameter(9) + compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE + constant.6781 = pred[] constant(true) + reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={} + param_3.1167 = f32[2,512]{1,0} parameter(3) + slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]} + bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366) + select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225) + add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160) + constant.5150 = s32[] constant(6) + dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265) + slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]} + constant.6776 = s32[] constant(0) + broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={} + param_8.464 = s32[3]{0} parameter(8) + compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE + constant.6775 = pred[] constant(true) + reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={} + param_2.1684 = f32[2,512]{1,0} parameter(2) + slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]} + bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364) + select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225) + add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159) + constant.5149 = s32[] constant(5) + dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265) + slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]} + constant.6770 = s32[] constant(0) + broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={} + param_7.458 = s32[3]{0} parameter(7) + compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE + constant.6769 = pred[] constant(true) + reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={} + param_1.1405 = f32[2,512]{1,0} parameter(1) + slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]} + bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362) + select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225) + add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158) + constant.5148 = s32[] constant(4) + ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265) +} + +ENTRY main { + param_0.0 = s32[] parameter(0) + param_1.0 = f32[2,512]{1,0} parameter(1) + param_2.0 = f32[2,512]{1,0} parameter(2) + param_3.0 = f32[2,512]{1,0} parameter(3) + param_4.0 = f32[2,20,256]{2,1,0} parameter(4) + param_5.0 = f32[2,512]{1,0} parameter(5) + param_6.0 = s32[3]{0} parameter(6) + param_7.0 = s32[3]{0} parameter(7) + param_8.0 = s32[3]{0} parameter(8) + param_9.0 = s32[3]{0} parameter(9) + fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1 + param_10 = pred[] parameter(10) + fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2 + ROOT root = (f32[2,20,256]{2,0,1}, f32[2,20,256]{2,0,1}) tuple(fusion.1, fusion.2) +} + )") + .ValueOrDie(); + EXPECT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index eefa4661d37..77c54e48a70 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -198,6 +198,42 @@ absl::optional CanShareBufferHint(const HloInstruction* user, return absl::nullopt; } +// Try to load ptx from files defined in the FLAGS. If successful, return true. +bool MaybeLoadPtxFromFile(const HloModule* module, std::string* ptx) { + // If the xla_gpu_ptx_file options is set, be explicit when a file is used + // and warn when a file is not used to ease catching typo in filename. + std::string prefix = xla::FilenameFor(*module, "", *ptx); + std::string matched_filename; + for (const string& full_filename : + module->config().debug_options().xla_gpu_ptx_file()) { + // To ease comparing many PTX versions, accept different suffixes then + // the original filename. + auto filename = tensorflow::io::Basename(full_filename); + if (absl::StartsWith(filename, prefix)) { + matched_filename = full_filename; + VLOG(0) << "RunBackend() - Will load PTX from file: " << full_filename; + break; + } + } + if (module->config().debug_options().xla_gpu_ptx_file().size() > 0 && + matched_filename.empty()) { + VLOG(0) << "RunBackend() - For module with prefix '" << prefix + << "', we did not found a PTX file to load."; + } + + if (!matched_filename.empty()) { + std::ifstream ifs(matched_filename, std::ifstream::in); + *ptx = std::string(std::istreambuf_iterator(ifs), + std::istreambuf_iterator()); + CHECK(!ptx->empty()) << "Empty or non existing PTX file: " + << matched_filename; + return true; + } + return false; +} + +} // namespace + // Prints a warning if the ptx->sass JIT in the driver has known bugs. // // Using such a driver only a problem if we fail to use ptxas to compile our ptx @@ -238,42 +274,6 @@ void WarnIfBadDriverJITVersion() { }); } -// Try to load ptx from files defined in the FLAGS. If successful, return true. -bool MaybeLoadPtxFromFile(const HloModule* module, std::string* ptx) { - // If the xla_gpu_ptx_file options is set, be explicit when a file is used - // and warn when a file is not used to ease catching typo in filename. - std::string prefix = xla::FilenameFor(*module, "", *ptx); - std::string matched_filename; - for (const string& full_filename : - module->config().debug_options().xla_gpu_ptx_file()) { - // To ease comparing many PTX versions, accept different suffixes then - // the original filename. - auto filename = tensorflow::io::Basename(full_filename); - if (absl::StartsWith(filename, prefix)) { - matched_filename = full_filename; - VLOG(0) << "RunBackend() - Will load PTX from file: " << full_filename; - break; - } - } - if (module->config().debug_options().xla_gpu_ptx_file().size() > 0 && - matched_filename.empty()) { - VLOG(0) << "RunBackend() - For module with prefix '" << prefix - << "', we did not found a PTX file to load."; - } - - if (!matched_filename.empty()) { - std::ifstream ifs(matched_filename, std::ifstream::in); - *ptx = std::string(std::istreambuf_iterator(ifs), - std::istreambuf_iterator()); - CHECK(!ptx->empty()) << "Empty or non existing PTX file: " - << matched_filename; - return true; - } - return false; -} - -} // namespace - NVPTXCompiler::NVPTXCompiler() : GpuCompiler(stream_executor::cuda::kCudaPlatformId, nvptx::kTargetTriple, nvptx::kDataLayout) {} @@ -415,7 +415,9 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( "using $PATH.", hlo_module_config); } - } else { + } else if (maybe_cubin.status().code() != + tensorflow::error::Code::UNIMPLEMENTED) { + // If unimplemented is returned, we fallback to the driver. LOG(FATAL) << "ptxas returned an error during compilation of ptx " "to sass: '" << maybe_cubin.status() << "' " diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index e69be947522..3e19b35af19 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -30,6 +30,8 @@ limitations under the License. namespace xla { namespace gpu { +void WarnIfBadDriverJITVersion(); + // NVPTXCompiler generates efficient GPU executables for NVPTX target. class NVPTXCompiler : public GpuCompiler { public: diff --git a/tensorflow/compiler/xla/service/gpu/tests/fusion.hlo b/tensorflow/compiler/xla/service/gpu/tests/fusion.hlo new file mode 100644 index 00000000000..73a56cb15ba --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/fusion.hlo @@ -0,0 +1,353 @@ +// RUN: hlo_to_llvm_ir %s | FileCheck %s + +HloModule TestModule + +// CHECK-LABEL: entry: +// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 +// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [64 x float]* +// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0 +// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [64 x float]* +// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0 +// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [64 x float]* +// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0 +// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [64 x float]* +// CHECK: %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0 +// CHECK: %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [64 x float]* +// CHECK: %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0 +// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [64 x float]* +// CHECK: %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0 +// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [128 x [112 x [112 x [64 x half]]]]* +// CHECK: %[[VAL_21:.*]] = getelementptr inbounds i8, i8* %[[VAL_22:.*]], i64 0 +// CHECK: %[[VAL_23:.*]] = bitcast i8* %[[VAL_21]] to [128 x [112 x [112 x [64 x half]]]]* +// CHECK: %[[VAL_24:.*]] = getelementptr inbounds i8, i8* %[[VAL_25:.*]], i64 0 +// CHECK: %[[VAL_26:.*]] = bitcast i8* %[[VAL_24]] to [128 x [112 x [112 x [64 x half]]]]* +// CHECK: %[[VAL_27:.*]] = getelementptr inbounds i8, i8* %[[VAL_28:.*]], i64 0 +// CHECK: %[[VAL_29:.*]] = bitcast i8* %[[VAL_27]] to [128 x [112 x [112 x [64 x half]]]]* +// CHECK: %[[VAL_30:.*]] = getelementptr inbounds i8, i8* %[[VAL_28]], i64 0 +// CHECK: %[[VAL_31:.*]] = bitcast i8* %[[VAL_30]] to [128 x [112 x [112 x [64 x half]]]]* +// CHECK: %[[VAL_32:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 +// CHECK: %[[VAL_33:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 +// CHECK: %[[VAL_34:.*]] = mul nuw nsw i32 %[[VAL_32]], 256 +// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_34]], %[[VAL_33]] +// CHECK: %[[VAL_36:.*]] = icmp ult i32 %[[VAL_35]], 25690112 +// CHECK: call void @llvm.assume(i1 %[[VAL_36]]) +// CHECK: %[[VAL_37:.*]] = mul nuw nsw i32 %[[VAL_35]], 4 +// CHECK: %[[VAL_38:.*]] = udiv i32 %[[VAL_37]], 1 +// CHECK: %[[VAL_39:.*]] = urem i32 %[[VAL_38]], 64 +// CHECK: %[[VAL_40:.*]] = udiv i32 %[[VAL_37]], 64 +// CHECK: %[[VAL_41:.*]] = urem i32 %[[VAL_40]], 112 +// CHECK: %[[VAL_42:.*]] = udiv i32 %[[VAL_37]], 7168 +// CHECK: %[[VAL_43:.*]] = urem i32 %[[VAL_42]], 112 +// CHECK: %[[VAL_44:.*]] = udiv i32 %[[VAL_37]], 802816 +// CHECK: %[[VAL_45:.*]] = add nuw nsw i32 %[[VAL_37]], 1 +// CHECK: %[[VAL_46:.*]] = udiv i32 %[[VAL_45]], 1 +// CHECK: %[[VAL_47:.*]] = urem i32 %[[VAL_46]], 64 +// CHECK: %[[VAL_48:.*]] = udiv i32 %[[VAL_45]], 64 +// CHECK: %[[VAL_49:.*]] = urem i32 %[[VAL_48]], 112 +// CHECK: %[[VAL_50:.*]] = udiv i32 %[[VAL_45]], 7168 +// CHECK: %[[VAL_51:.*]] = urem i32 %[[VAL_50]], 112 +// CHECK: %[[VAL_52:.*]] = udiv i32 %[[VAL_45]], 802816 +// CHECK: %[[VAL_53:.*]] = add nuw nsw i32 %[[VAL_37]], 2 +// CHECK: %[[VAL_54:.*]] = udiv i32 %[[VAL_53]], 1 +// CHECK: %[[VAL_55:.*]] = urem i32 %[[VAL_54]], 64 +// CHECK: %[[VAL_56:.*]] = udiv i32 %[[VAL_53]], 64 +// CHECK: %[[VAL_57:.*]] = urem i32 %[[VAL_56]], 112 +// CHECK: %[[VAL_58:.*]] = udiv i32 %[[VAL_53]], 7168 +// CHECK: %[[VAL_59:.*]] = urem i32 %[[VAL_58]], 112 +// CHECK: %[[VAL_60:.*]] = udiv i32 %[[VAL_53]], 802816 +// CHECK: %[[VAL_61:.*]] = add nuw nsw i32 %[[VAL_37]], 3 +// CHECK: %[[VAL_62:.*]] = udiv i32 %[[VAL_61]], 1 +// CHECK: %[[VAL_63:.*]] = urem i32 %[[VAL_62]], 64 +// CHECK: %[[VAL_64:.*]] = udiv i32 %[[VAL_61]], 64 +// CHECK: %[[VAL_65:.*]] = urem i32 %[[VAL_64]], 112 +// CHECK: %[[VAL_66:.*]] = udiv i32 %[[VAL_61]], 7168 +// CHECK: %[[VAL_67:.*]] = urem i32 %[[VAL_66]], 112 +// CHECK: %[[VAL_68:.*]] = udiv i32 %[[VAL_61]], 802816 +// CHECK: %[[VAL_69:.*]] = icmp ult i32 %[[VAL_37]], 102760448 +// CHECK: br i1 %[[VAL_69]], label %[[VAL_70:.*]], label %[[VAL_71:.*]] +// CHECK: fusion.1.in_bounds-after: ; preds = %[[VAL_70]], %[[VAL_72:.*]] +// CHECK: ret void +// CHECK: fusion.1.in_bounds-true: ; preds = %[[VAL_72]] +// CHECK: %[[VAL_73:.*]] = urem i32 %[[VAL_37]], 64 +// CHECK: %[[VAL_74:.*]] = bitcast [64 x float]* %[[VAL_14]] to float* +// CHECK: %[[VAL_75:.*]] = getelementptr inbounds float, float* %[[VAL_74]], i32 %[[VAL_73]] +// CHECK: %[[VAL_76:.*]] = load float, float* %[[VAL_75]], align 4, !invariant.load !4 +// CHECK: %[[VAL_77:.*]] = urem i32 %[[VAL_37]], 64 +// CHECK: %[[VAL_78:.*]] = bitcast [64 x float]* %[[VAL_11]] to float* +// CHECK: %[[VAL_79:.*]] = getelementptr inbounds float, float* %[[VAL_78]], i32 %[[VAL_77]] +// CHECK: %[[VAL_80:.*]] = load float, float* %[[VAL_79]], align 4, !invariant.load !4 +// CHECK: %[[VAL_81:.*]] = fmul float %[[VAL_76]], %[[VAL_80]] +// CHECK: %[[VAL_82:.*]] = load float, float* bitcast ([4 x i8]* @0 to float*), align 4 +// CHECK: %[[VAL_83:.*]] = fmul float %[[VAL_81]], %[[VAL_82]] +// CHECK: %[[VAL_84:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_26]] to half* +// CHECK: %[[VAL_85:.*]] = getelementptr inbounds half, half* %[[VAL_84]], i32 %[[VAL_37]] +// CHECK: %[[VAL_86:.*]] = load half, half* %[[VAL_85]], align 2, !invariant.load !4 +// CHECK: %[[VAL_87:.*]] = load half, half* bitcast ([2 x i8]* @1 to half*), align 2 +// CHECK: %[[VAL_88:.*]] = fcmp ogt half %[[VAL_86]], %[[VAL_87]] +// CHECK: %[[VAL_89:.*]] = zext i1 %[[VAL_88]] to i8 +// CHECK: %[[VAL_90:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_23]] to half* +// CHECK: %[[VAL_91:.*]] = getelementptr inbounds half, half* %[[VAL_90]], i32 %[[VAL_37]] +// CHECK: %[[VAL_92:.*]] = load half, half* %[[VAL_91]], align 2, !invariant.load !4 +// CHECK: %[[VAL_93:.*]] = trunc i8 %[[VAL_89]] to i1 +// CHECK: %[[VAL_94:.*]] = select i1 %[[VAL_93]], half %[[VAL_92]], half %[[VAL_87]] +// CHECK: %[[VAL_95:.*]] = fpext half %[[VAL_94]] to float +// CHECK: %[[VAL_96:.*]] = load float, float* bitcast ([4 x i8]* @2 to float*), align 4 +// CHECK: %[[VAL_97:.*]] = fmul float %[[VAL_95]], %[[VAL_96]] +// CHECK: %[[VAL_98:.*]] = urem i32 %[[VAL_37]], 64 +// CHECK: %[[VAL_99:.*]] = bitcast [64 x float]* %[[VAL_8]] to float* +// CHECK: %[[VAL_100:.*]] = getelementptr inbounds float, float* %[[VAL_99]], i32 %[[VAL_98]] +// CHECK: %[[VAL_101:.*]] = load float, float* %[[VAL_100]], align 4, !invariant.load !4 +// CHECK: %[[VAL_102:.*]] = fsub float %[[VAL_97]], %[[VAL_101]] +// CHECK: %[[VAL_103:.*]] = urem i32 %[[VAL_37]], 64 +// CHECK: %[[VAL_104:.*]] = bitcast [64 x float]* %[[VAL_5]] to float* +// CHECK: %[[VAL_105:.*]] = getelementptr inbounds float, float* %[[VAL_104]], i32 %[[VAL_103]] +// CHECK: %[[VAL_106:.*]] = load float, float* %[[VAL_105]], align 4, !invariant.load !4 +// CHECK: %[[VAL_107:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_20]] to half* +// CHECK: %[[VAL_108:.*]] = getelementptr inbounds half, half* %[[VAL_107]], i32 %[[VAL_37]] +// CHECK: %[[VAL_109:.*]] = load half, half* %[[VAL_108]], align 2, !invariant.load !4 +// CHECK: %[[VAL_110:.*]] = fpext half %[[VAL_109]] to float +// CHECK: %[[VAL_111:.*]] = urem i32 %[[VAL_37]], 64 +// CHECK: %[[VAL_112:.*]] = bitcast [64 x float]* %[[VAL_17]] to float* +// CHECK: %[[VAL_113:.*]] = getelementptr inbounds float, float* %[[VAL_112]], i32 %[[VAL_111]] +// CHECK: %[[VAL_114:.*]] = load float, float* %[[VAL_113]], align 4, !invariant.load !4 +// CHECK: %[[VAL_115:.*]] = load float, float* bitcast ([4 x i8]* @3 to float*), align 4 +// CHECK: %[[VAL_116:.*]] = fmul float %[[VAL_114]], %[[VAL_115]] +// CHECK: %[[VAL_117:.*]] = fsub float %[[VAL_110]], %[[VAL_116]] +// CHECK: %[[VAL_118:.*]] = fmul float %[[VAL_106]], %[[VAL_117]] +// CHECK: %[[VAL_119:.*]] = urem i32 %[[VAL_37]], 64 +// CHECK: %[[VAL_120:.*]] = bitcast [64 x float]* %[[VAL_2]] to float* +// CHECK: %[[VAL_121:.*]] = getelementptr inbounds float, float* %[[VAL_120]], i32 %[[VAL_119]] +// CHECK: %[[VAL_122:.*]] = load float, float* %[[VAL_121]], align 4, !invariant.load !4 +// CHECK: %[[VAL_123:.*]] = fdiv float %[[VAL_118]], %[[VAL_122]] +// CHECK: %[[VAL_124:.*]] = fsub float %[[VAL_102]], %[[VAL_123]] +// CHECK: %[[VAL_125:.*]] = fmul float %[[VAL_83]], %[[VAL_124]] +// CHECK: %[[VAL_126:.*]] = fptrunc float %[[VAL_125]] to half +// CHECK: %[[VAL_127:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_29]] to half* +// CHECK: %[[VAL_128:.*]] = getelementptr inbounds half, half* %[[VAL_127]], i32 %[[VAL_37]] +// CHECK: store half %[[VAL_126]], half* %[[VAL_128]], align 2 +// CHECK: %[[VAL_129:.*]] = urem i32 %[[VAL_45]], 64 +// CHECK: %[[VAL_130:.*]] = bitcast [64 x float]* %[[VAL_14]] to float* +// CHECK: %[[VAL_131:.*]] = getelementptr inbounds float, float* %[[VAL_130]], i32 %[[VAL_129]] +// CHECK: %[[VAL_132:.*]] = load float, float* %[[VAL_131]], align 4, !invariant.load !4 +// CHECK: %[[VAL_133:.*]] = urem i32 %[[VAL_45]], 64 +// CHECK: %[[VAL_134:.*]] = bitcast [64 x float]* %[[VAL_11]] to float* +// CHECK: %[[VAL_135:.*]] = getelementptr inbounds float, float* %[[VAL_134]], i32 %[[VAL_133]] +// CHECK: %[[VAL_136:.*]] = load float, float* %[[VAL_135]], align 4, !invariant.load !4 +// CHECK: %[[VAL_137:.*]] = fmul float %[[VAL_132]], %[[VAL_136]] +// CHECK: %[[VAL_138:.*]] = load float, float* bitcast ([4 x i8]* @4 to float*), align 4 +// CHECK: %[[VAL_139:.*]] = fmul float %[[VAL_137]], %[[VAL_138]] +// CHECK: %[[VAL_140:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_26]] to half* +// CHECK: %[[VAL_141:.*]] = getelementptr inbounds half, half* %[[VAL_140]], i32 %[[VAL_45]] +// CHECK: %[[VAL_142:.*]] = load half, half* %[[VAL_141]], align 2, !invariant.load !4 +// CHECK: %[[VAL_143:.*]] = load half, half* bitcast ([2 x i8]* @5 to half*), align 2 +// CHECK: %[[VAL_144:.*]] = fcmp ogt half %[[VAL_142]], %[[VAL_143]] +// CHECK: %[[VAL_145:.*]] = zext i1 %[[VAL_144]] to i8 +// CHECK: %[[VAL_146:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_23]] to half* +// CHECK: %[[VAL_147:.*]] = getelementptr inbounds half, half* %[[VAL_146]], i32 %[[VAL_45]] +// CHECK: %[[VAL_148:.*]] = load half, half* %[[VAL_147]], align 2, !invariant.load !4 +// CHECK: %[[VAL_149:.*]] = trunc i8 %[[VAL_145]] to i1 +// CHECK: %[[VAL_150:.*]] = select i1 %[[VAL_149]], half %[[VAL_148]], half %[[VAL_143]] +// CHECK: %[[VAL_151:.*]] = fpext half %[[VAL_150]] to float +// CHECK: %[[VAL_152:.*]] = load float, float* bitcast ([4 x i8]* @6 to float*), align 4 +// CHECK: %[[VAL_153:.*]] = fmul float %[[VAL_151]], %[[VAL_152]] +// CHECK: %[[VAL_154:.*]] = urem i32 %[[VAL_45]], 64 +// CHECK: %[[VAL_155:.*]] = bitcast [64 x float]* %[[VAL_8]] to float* +// CHECK: %[[VAL_156:.*]] = getelementptr inbounds float, float* %[[VAL_155]], i32 %[[VAL_154]] +// CHECK: %[[VAL_157:.*]] = load float, float* %[[VAL_156]], align 4, !invariant.load !4 +// CHECK: %[[VAL_158:.*]] = fsub float %[[VAL_153]], %[[VAL_157]] +// CHECK: %[[VAL_159:.*]] = urem i32 %[[VAL_45]], 64 +// CHECK: %[[VAL_160:.*]] = bitcast [64 x float]* %[[VAL_5]] to float* +// CHECK: %[[VAL_161:.*]] = getelementptr inbounds float, float* %[[VAL_160]], i32 %[[VAL_159]] +// CHECK: %[[VAL_162:.*]] = load float, float* %[[VAL_161]], align 4, !invariant.load !4 +// CHECK: %[[VAL_163:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_20]] to half* +// CHECK: %[[VAL_164:.*]] = getelementptr inbounds half, half* %[[VAL_163]], i32 %[[VAL_45]] +// CHECK: %[[VAL_165:.*]] = load half, half* %[[VAL_164]], align 2, !invariant.load !4 +// CHECK: %[[VAL_166:.*]] = fpext half %[[VAL_165]] to float +// CHECK: %[[VAL_167:.*]] = urem i32 %[[VAL_45]], 64 +// CHECK: %[[VAL_168:.*]] = bitcast [64 x float]* %[[VAL_17]] to float* +// CHECK: %[[VAL_169:.*]] = getelementptr inbounds float, float* %[[VAL_168]], i32 %[[VAL_167]] +// CHECK: %[[VAL_170:.*]] = load float, float* %[[VAL_169]], align 4, !invariant.load !4 +// CHECK: %[[VAL_171:.*]] = load float, float* bitcast ([4 x i8]* @7 to float*), align 4 +// CHECK: %[[VAL_172:.*]] = fmul float %[[VAL_170]], %[[VAL_171]] +// CHECK: %[[VAL_173:.*]] = fsub float %[[VAL_166]], %[[VAL_172]] +// CHECK: %[[VAL_174:.*]] = fmul float %[[VAL_162]], %[[VAL_173]] +// CHECK: %[[VAL_175:.*]] = urem i32 %[[VAL_45]], 64 +// CHECK: %[[VAL_176:.*]] = bitcast [64 x float]* %[[VAL_2]] to float* +// CHECK: %[[VAL_177:.*]] = getelementptr inbounds float, float* %[[VAL_176]], i32 %[[VAL_175]] +// CHECK: %[[VAL_178:.*]] = load float, float* %[[VAL_177]], align 4, !invariant.load !4 +// CHECK: %[[VAL_179:.*]] = fdiv float %[[VAL_174]], %[[VAL_178]] +// CHECK: %[[VAL_180:.*]] = fsub float %[[VAL_158]], %[[VAL_179]] +// CHECK: %[[VAL_181:.*]] = fmul float %[[VAL_139]], %[[VAL_180]] +// CHECK: %[[VAL_182:.*]] = fptrunc float %[[VAL_181]] to half +// CHECK: %[[VAL_183:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_29]] to half* +// CHECK: %[[VAL_184:.*]] = getelementptr inbounds half, half* %[[VAL_183]], i32 %[[VAL_45]] +// CHECK: store half %[[VAL_182]], half* %[[VAL_184]], align 2 +// CHECK: %[[VAL_185:.*]] = urem i32 %[[VAL_53]], 64 +// CHECK: %[[VAL_186:.*]] = bitcast [64 x float]* %[[VAL_14]] to float* +// CHECK: %[[VAL_187:.*]] = getelementptr inbounds float, float* %[[VAL_186]], i32 %[[VAL_185]] +// CHECK: %[[VAL_188:.*]] = load float, float* %[[VAL_187]], align 4, !invariant.load !4 +// CHECK: %[[VAL_189:.*]] = urem i32 %[[VAL_53]], 64 +// CHECK: %[[VAL_190:.*]] = bitcast [64 x float]* %[[VAL_11]] to float* +// CHECK: %[[VAL_191:.*]] = getelementptr inbounds float, float* %[[VAL_190]], i32 %[[VAL_189]] +// CHECK: %[[VAL_192:.*]] = load float, float* %[[VAL_191]], align 4, !invariant.load !4 +// CHECK: %[[VAL_193:.*]] = fmul float %[[VAL_188]], %[[VAL_192]] +// CHECK: %[[VAL_194:.*]] = load float, float* bitcast ([4 x i8]* @8 to float*), align 4 +// CHECK: %[[VAL_195:.*]] = fmul float %[[VAL_193]], %[[VAL_194]] +// CHECK: %[[VAL_196:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_26]] to half* +// CHECK: %[[VAL_197:.*]] = getelementptr inbounds half, half* %[[VAL_196]], i32 %[[VAL_53]] +// CHECK: %[[VAL_198:.*]] = load half, half* %[[VAL_197]], align 2, !invariant.load !4 +// CHECK: %[[VAL_199:.*]] = load half, half* bitcast ([2 x i8]* @9 to half*), align 2 +// CHECK: %[[VAL_200:.*]] = fcmp ogt half %[[VAL_198]], %[[VAL_199]] +// CHECK: %[[VAL_201:.*]] = zext i1 %[[VAL_200]] to i8 +// CHECK: %[[VAL_202:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_23]] to half* +// CHECK: %[[VAL_203:.*]] = getelementptr inbounds half, half* %[[VAL_202]], i32 %[[VAL_53]] +// CHECK: %[[VAL_204:.*]] = load half, half* %[[VAL_203]], align 2, !invariant.load !4 +// CHECK: %[[VAL_205:.*]] = trunc i8 %[[VAL_201]] to i1 +// CHECK: %[[VAL_206:.*]] = select i1 %[[VAL_205]], half %[[VAL_204]], half %[[VAL_199]] +// CHECK: %[[VAL_207:.*]] = fpext half %[[VAL_206]] to float +// CHECK: %[[VAL_208:.*]] = load float, float* bitcast ([4 x i8]* @10 to float*), align 4 +// CHECK: %[[VAL_209:.*]] = fmul float %[[VAL_207]], %[[VAL_208]] +// CHECK: %[[VAL_210:.*]] = urem i32 %[[VAL_53]], 64 +// CHECK: %[[VAL_211:.*]] = bitcast [64 x float]* %[[VAL_8]] to float* +// CHECK: %[[VAL_212:.*]] = getelementptr inbounds float, float* %[[VAL_211]], i32 %[[VAL_210]] +// CHECK: %[[VAL_213:.*]] = load float, float* %[[VAL_212]], align 4, !invariant.load !4 +// CHECK: %[[VAL_214:.*]] = fsub float %[[VAL_209]], %[[VAL_213]] +// CHECK: %[[VAL_215:.*]] = urem i32 %[[VAL_53]], 64 +// CHECK: %[[VAL_216:.*]] = bitcast [64 x float]* %[[VAL_5]] to float* +// CHECK: %[[VAL_217:.*]] = getelementptr inbounds float, float* %[[VAL_216]], i32 %[[VAL_215]] +// CHECK: %[[VAL_218:.*]] = load float, float* %[[VAL_217]], align 4, !invariant.load !4 +// CHECK: %[[VAL_219:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_20]] to half* +// CHECK: %[[VAL_220:.*]] = getelementptr inbounds half, half* %[[VAL_219]], i32 %[[VAL_53]] +// CHECK: %[[VAL_221:.*]] = load half, half* %[[VAL_220]], align 2, !invariant.load !4 +// CHECK: %[[VAL_222:.*]] = fpext half %[[VAL_221]] to float +// CHECK: %[[VAL_223:.*]] = urem i32 %[[VAL_53]], 64 +// CHECK: %[[VAL_224:.*]] = bitcast [64 x float]* %[[VAL_17]] to float* +// CHECK: %[[VAL_225:.*]] = getelementptr inbounds float, float* %[[VAL_224]], i32 %[[VAL_223]] +// CHECK: %[[VAL_226:.*]] = load float, float* %[[VAL_225]], align 4, !invariant.load !4 +// CHECK: %[[VAL_227:.*]] = load float, float* bitcast ([4 x i8]* @11 to float*), align 4 +// CHECK: %[[VAL_228:.*]] = fmul float %[[VAL_226]], %[[VAL_227]] +// CHECK: %[[VAL_229:.*]] = fsub float %[[VAL_222]], %[[VAL_228]] +// CHECK: %[[VAL_230:.*]] = fmul float %[[VAL_218]], %[[VAL_229]] +// CHECK: %[[VAL_231:.*]] = urem i32 %[[VAL_53]], 64 +// CHECK: %[[VAL_232:.*]] = bitcast [64 x float]* %[[VAL_2]] to float* +// CHECK: %[[VAL_233:.*]] = getelementptr inbounds float, float* %[[VAL_232]], i32 %[[VAL_231]] +// CHECK: %[[VAL_234:.*]] = load float, float* %[[VAL_233]], align 4, !invariant.load !4 +// CHECK: %[[VAL_235:.*]] = fdiv float %[[VAL_230]], %[[VAL_234]] +// CHECK: %[[VAL_236:.*]] = fsub float %[[VAL_214]], %[[VAL_235]] +// CHECK: %[[VAL_237:.*]] = fmul float %[[VAL_195]], %[[VAL_236]] +// CHECK: %[[VAL_238:.*]] = fptrunc float %[[VAL_237]] to half +// CHECK: %[[VAL_239:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_29]] to half* +// CHECK: %[[VAL_240:.*]] = getelementptr inbounds half, half* %[[VAL_239]], i32 %[[VAL_53]] +// CHECK: store half %[[VAL_238]], half* %[[VAL_240]], align 2 +// CHECK: %[[VAL_241:.*]] = urem i32 %[[VAL_61]], 64 +// CHECK: %[[VAL_242:.*]] = bitcast [64 x float]* %[[VAL_14]] to float* +// CHECK: %[[VAL_243:.*]] = getelementptr inbounds float, float* %[[VAL_242]], i32 %[[VAL_241]] +// CHECK: %[[VAL_244:.*]] = load float, float* %[[VAL_243]], align 4, !invariant.load !4 +// CHECK: %[[VAL_245:.*]] = urem i32 %[[VAL_61]], 64 +// CHECK: %[[VAL_246:.*]] = bitcast [64 x float]* %[[VAL_11]] to float* +// CHECK: %[[VAL_247:.*]] = getelementptr inbounds float, float* %[[VAL_246]], i32 %[[VAL_245]] +// CHECK: %[[VAL_248:.*]] = load float, float* %[[VAL_247]], align 4, !invariant.load !4 +// CHECK: %[[VAL_249:.*]] = fmul float %[[VAL_244]], %[[VAL_248]] +// CHECK: %[[VAL_250:.*]] = load float, float* bitcast ([4 x i8]* @12 to float*), align 4 +// CHECK: %[[VAL_251:.*]] = fmul float %[[VAL_249]], %[[VAL_250]] +// CHECK: %[[VAL_252:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_26]] to half* +// CHECK: %[[VAL_253:.*]] = getelementptr inbounds half, half* %[[VAL_252]], i32 %[[VAL_61]] +// CHECK: %[[VAL_254:.*]] = load half, half* %[[VAL_253]], align 2, !invariant.load !4 +// CHECK: %[[VAL_255:.*]] = load half, half* bitcast ([2 x i8]* @13 to half*), align 2 +// CHECK: %[[VAL_256:.*]] = fcmp ogt half %[[VAL_254]], %[[VAL_255]] +// CHECK: %[[VAL_257:.*]] = zext i1 %[[VAL_256]] to i8 +// CHECK: %[[VAL_258:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_23]] to half* +// CHECK: %[[VAL_259:.*]] = getelementptr inbounds half, half* %[[VAL_258]], i32 %[[VAL_61]] +// CHECK: %[[VAL_260:.*]] = load half, half* %[[VAL_259]], align 2, !invariant.load !4 +// CHECK: %[[VAL_261:.*]] = trunc i8 %[[VAL_257]] to i1 +// CHECK: %[[VAL_262:.*]] = select i1 %[[VAL_261]], half %[[VAL_260]], half %[[VAL_255]] +// CHECK: %[[VAL_263:.*]] = fpext half %[[VAL_262]] to float +// CHECK: %[[VAL_264:.*]] = load float, float* bitcast ([4 x i8]* @14 to float*), align 4 +// CHECK: %[[VAL_265:.*]] = fmul float %[[VAL_263]], %[[VAL_264]] +// CHECK: %[[VAL_266:.*]] = urem i32 %[[VAL_61]], 64 +// CHECK: %[[VAL_267:.*]] = bitcast [64 x float]* %[[VAL_8]] to float* +// CHECK: %[[VAL_268:.*]] = getelementptr inbounds float, float* %[[VAL_267]], i32 %[[VAL_266]] +// CHECK: %[[VAL_269:.*]] = load float, float* %[[VAL_268]], align 4, !invariant.load !4 +// CHECK: %[[VAL_270:.*]] = fsub float %[[VAL_265]], %[[VAL_269]] +// CHECK: %[[VAL_271:.*]] = urem i32 %[[VAL_61]], 64 +// CHECK: %[[VAL_272:.*]] = bitcast [64 x float]* %[[VAL_5]] to float* +// CHECK: %[[VAL_273:.*]] = getelementptr inbounds float, float* %[[VAL_272]], i32 %[[VAL_271]] +// CHECK: %[[VAL_274:.*]] = load float, float* %[[VAL_273]], align 4, !invariant.load !4 +// CHECK: %[[VAL_275:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_20]] to half* +// CHECK: %[[VAL_276:.*]] = getelementptr inbounds half, half* %[[VAL_275]], i32 %[[VAL_61]] +// CHECK: %[[VAL_277:.*]] = load half, half* %[[VAL_276]], align 2, !invariant.load !4 +// CHECK: %[[VAL_278:.*]] = fpext half %[[VAL_277]] to float +// CHECK: %[[VAL_279:.*]] = urem i32 %[[VAL_61]], 64 +// CHECK: %[[VAL_280:.*]] = bitcast [64 x float]* %[[VAL_17]] to float* +// CHECK: %[[VAL_281:.*]] = getelementptr inbounds float, float* %[[VAL_280]], i32 %[[VAL_279]] +// CHECK: %[[VAL_282:.*]] = load float, float* %[[VAL_281]], align 4, !invariant.load !4 +// CHECK: %[[VAL_283:.*]] = load float, float* bitcast ([4 x i8]* @15 to float*), align 4 +// CHECK: %[[VAL_284:.*]] = fmul float %[[VAL_282]], %[[VAL_283]] +// CHECK: %[[VAL_285:.*]] = fsub float %[[VAL_278]], %[[VAL_284]] +// CHECK: %[[VAL_286:.*]] = fmul float %[[VAL_274]], %[[VAL_285]] +// CHECK: %[[VAL_287:.*]] = urem i32 %[[VAL_61]], 64 +// CHECK: %[[VAL_288:.*]] = bitcast [64 x float]* %[[VAL_2]] to float* +// CHECK: %[[VAL_289:.*]] = getelementptr inbounds float, float* %[[VAL_288]], i32 %[[VAL_287]] +// CHECK: %[[VAL_290:.*]] = load float, float* %[[VAL_289]], align 4, !invariant.load !4 +// CHECK: %[[VAL_291:.*]] = fdiv float %[[VAL_286]], %[[VAL_290]] +// CHECK: %[[VAL_292:.*]] = fsub float %[[VAL_270]], %[[VAL_291]] +// CHECK: %[[VAL_293:.*]] = fmul float %[[VAL_251]], %[[VAL_292]] +// CHECK: %[[VAL_294:.*]] = fptrunc float %[[VAL_293]] to half +// CHECK: %[[VAL_295:.*]] = bitcast [128 x [112 x [112 x [64 x half]]]]* %[[VAL_29]] to half* +// CHECK: %[[VAL_296:.*]] = getelementptr inbounds half, half* %[[VAL_295]], i32 %[[VAL_61]] +// CHECK: store half %[[VAL_294]], half* %[[VAL_296]], align 2 +// CHECK: br label %[[VAL_71]] + +%fused_computation.1 (param_0.5: f32[64], param_1.3088: f32[64], param_2.2116: f32[64], param_3.974: f32[64], param_4.1162: f32[64], param_5.893: f32[64], param_6.809: f16[128,64,112,112], param_7.770: f16[128,64,112,112], param_8.637: f16[128,64,112,112]) -> f16[128,64,112,112] { + %param_4.1162 = f32[64]{0} parameter(4) + %broadcast.2313 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_4.1162), dimensions={1} + %param_3.974 = f32[64]{0} parameter(3) + %broadcast.1844 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_3.974), dimensions={1} + %multiply.1049 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %broadcast.2313, f32[128,64,112,112]{1,3,2,0} %broadcast.1844) + %constant_1404 = f32[] constant(6.22807704e-07) + %broadcast.1843 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[] %constant_1404), dimensions={} + %multiply.1048 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %multiply.1049, f32[128,64,112,112]{1,3,2,0} %broadcast.1843) + %param_8.637 = f16[128,64,112,112]{1,3,2,0} parameter(8) + %constant_3626 = f16[] constant(0) + %broadcast.4770 = f16[128,64,112,112]{1,3,2,0} broadcast(f16[] %constant_3626), dimensions={} + %compare.259 = pred[128,64,112,112]{1,3,2,0} compare(f16[128,64,112,112]{1,3,2,0} %param_8.637, f16[128,64,112,112]{1,3,2,0} %broadcast.4770), direction=GT + %param_7.770 = f16[128,64,112,112]{1,3,2,0} parameter(7) + %select.254 = f16[128,64,112,112]{1,3,2,0} select(pred[128,64,112,112]{1,3,2,0} %compare.259, f16[128,64,112,112]{1,3,2,0} %param_7.770, f16[128,64,112,112]{1,3,2,0} %broadcast.4770) + %convert.108 = f32[128,64,112,112]{1,3,2,0} convert(f16[128,64,112,112]{1,3,2,0} %select.254) + %constant_1390 = f32[] constant(1605632) + %broadcast.1841 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[] %constant_1390), dimensions={} + %multiply.1046 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %convert.108, f32[128,64,112,112]{1,3,2,0} %broadcast.1841) + %param_2.2116 = f32[64]{0} parameter(2) + %broadcast.1840 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_2.2116), dimensions={1} + %subtract.266 = f32[128,64,112,112]{1,3,2,0} subtract(f32[128,64,112,112]{1,3,2,0} %multiply.1046, f32[128,64,112,112]{1,3,2,0} %broadcast.1840) + %param_1.3088 = f32[64]{0} parameter(1) + %broadcast.1839 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_1.3088), dimensions={1} + %param_6.809 = f16[128,64,112,112]{1,3,2,0} parameter(6) + %convert.644 = f32[128,64,112,112]{1,3,2,0} convert(f16[128,64,112,112]{1,3,2,0} %param_6.809) + %param_5.893 = f32[64]{0} parameter(5) + %broadcast.3388 = f32[64]{0} broadcast(f32[] %constant_1404), dimensions={} + %multiply.2336 = f32[64]{0} multiply(f32[64]{0} %param_5.893, f32[64]{0} %broadcast.3388) + %broadcast.3387 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %multiply.2336), dimensions={1} + %subtract.591 = f32[128,64,112,112]{1,3,2,0} subtract(f32[128,64,112,112]{1,3,2,0} %convert.644, f32[128,64,112,112]{1,3,2,0} %broadcast.3387) + %multiply.1045 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %broadcast.1839, f32[128,64,112,112]{1,3,2,0} %subtract.591) + %param_0.5 = f32[64]{0} parameter(0) + %broadcast.1838 = f32[128,64,112,112]{1,3,2,0} broadcast(f32[64]{0} %param_0.5), dimensions={1} + %divide.212 = f32[128,64,112,112]{1,3,2,0} divide(f32[128,64,112,112]{1,3,2,0} %multiply.1045, f32[128,64,112,112]{1,3,2,0} %broadcast.1838) + %subtract.265 = f32[128,64,112,112]{1,3,2,0} subtract(f32[128,64,112,112]{1,3,2,0} %subtract.266, f32[128,64,112,112]{1,3,2,0} %divide.212) + %multiply.1044 = f32[128,64,112,112]{1,3,2,0} multiply(f32[128,64,112,112]{1,3,2,0} %multiply.1048, f32[128,64,112,112]{1,3,2,0} %subtract.265) + ROOT %convert.107 = f16[128,64,112,112]{1,3,2,0} convert(f32[128,64,112,112]{1,3,2,0} %multiply.1044) +} + +ENTRY main { + %get-tuple-element.1532 = f32[64]{0} parameter(0) + %get-tuple-element.876 = f32[64]{0} parameter(1) + %get-tuple-element.877 = f32[64]{0} parameter(2) + %get-tuple-element.1530 = f32[64]{0} parameter(3) + %arg112.113 = f32[64]{0} parameter(4) + %get-tuple-element.881 = f32[64]{0} parameter(5) + %get-tuple-element.872 = f16[128,64,112,112]{1,3,2,0} parameter(6) + %select-and-scatter.3626 = f16[128,64,112,112]{1,3,2,0} parameter(7) + %fusion.845 = f16[128,64,112,112]{1,3,2,0} parameter(8) + + ROOT %fusion.1 = f16[128,64,112,112]{1,3,2,0} fusion(f32[64]{0} %get-tuple-element.1532, f32[64]{0} %get-tuple-element.876, f32[64]{0} %get-tuple-element.877, f32[64]{0} %get-tuple-element.1530, f32[64]{0} %arg112.113, f32[64]{0} %get-tuple-element.881, f16[128,64,112,112]{1,3,2,0} %get-tuple-element.872, f16[128,64,112,112]{1,3,2,0} %select-and-scatter.3626, f16[128,64,112,112]{1,3,2,0} %fusion.845), kind=kLoop, calls=%fused_computation.1 +} diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index d1bece038e0..6ed378adfeb 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -833,7 +833,7 @@ TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) { )"; auto hlo_module = ParseAndReturnVerifiedModule(kHloString).ValueOrDie(); auto expected_ir = R"( -; CHECK: shared_cache_{{[0-9]*}} = private addrspace({{[0-9]*}}) global [1 x [32 x float]] +; CHECK: shared_cache_{{[0-9]*}} = private unnamed_addr addrspace({{[0-9]*}}) global [1 x [32 x float]] )"; CompileAndVerifyIr(std::move(hlo_module), expected_ir, /*match_optimized_ir=*/true); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index 8ec00d73711..95cb01dd17e 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -50,11 +50,10 @@ TEST_F(GpuNoAliasTest, Concat) { auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); - CompileAndVerifyIr(std::move(hlo_module), - R"(CHECK-LABEL: define{{.*}}void @fusion - CHECK-SAME: i8* noalias align {{[0-9]*}} dereferenceable({{[0-9]*}}) %[[OUTPUT_ALLOC:[a-z0-9]*]] - CHECK: %fusion.raw = {{.*}} %[[OUTPUT_ALLOC]])", - /*match_optimized_ir=*/false); + CompileAndVerifyIr( + std::move(hlo_module), + R"(CHECK: define{{.*}}void @fusion(i8* noalias align {{[0-9]*}} dereferenceable({{[0-9]*}}) %{{.*}}, i8* noalias align {{[0-9]*}} dereferenceable({{[0-9]*}}) %{{.*}}, i8* noalias align {{[0-9]*}} dereferenceable({{[0-9]*}}) %{{.*}}))", + /*match_optimized_ir=*/false); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo index f625abe6612..e929efb6d54 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo @@ -3,14 +3,12 @@ // CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) { // CHECK: entry: // CHECK: %[[VAL_32:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 -// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x [3 x i32]]* -// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0 -// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [3 x [3 x i32]]* // CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0 // CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [2 x i32]* // CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0 // CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [2 x [3 x i32]]* +// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 +// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [3 x [3 x i32]]* // CHECK: %[[VAL_12:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 // CHECK: %[[VAL_13:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 // CHECK: %[[VAL_14:.*]] = mul nuw nsw i32 %[[VAL_12]], 6 @@ -75,14 +73,12 @@ ENTRY main { // CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) { // CHECK: entry: // CHECK: %[[VAL_60:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0 -// CHECK: %[[VAL_39:.*]] = bitcast i8* %[[VAL_37]] to i32* -// CHECK: %[[VAL_40:.*]] = getelementptr inbounds i8, i8* %[[VAL_41:.*]], i64 0 -// CHECK: %[[VAL_42:.*]] = bitcast i8* %[[VAL_40]] to i32* // CHECK: %[[VAL_43:.*]] = getelementptr inbounds i8, i8* %[[VAL_44:.*]], i64 0 // CHECK: %[[VAL_45:.*]] = bitcast i8* %[[VAL_43]] to [0 x i32]* // CHECK: %[[VAL_46:.*]] = getelementptr inbounds i8, i8* %[[VAL_47:.*]], i64 0 // CHECK: %[[VAL_48:.*]] = bitcast i8* %[[VAL_46]] to i32* +// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0 +// CHECK: %[[VAL_39:.*]] = bitcast i8* %[[VAL_37]] to i32* // CHECK: %[[VAL_49:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 // CHECK: %[[VAL_50:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 // CHECK: %[[VAL_51:.*]] = mul nuw nsw i32 %[[VAL_49]], 1 @@ -135,14 +131,12 @@ ENTRY main { // CHECK: %[[VAL_63:.*]] = alloca i32, align 4 // CHECK: %[[VAL_64:.*]] = alloca i32, align 4 // CHECK: %[[VAL_98:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_65:.*]] = getelementptr inbounds i8, i8* %[[VAL_66:.*]], i64 0 -// CHECK: %[[VAL_67:.*]] = bitcast i8* %[[VAL_65]] to [3 x [3 x i32]]* -// CHECK: %[[VAL_68:.*]] = getelementptr inbounds i8, i8* %[[VAL_69:.*]], i64 0 -// CHECK: %[[VAL_70:.*]] = bitcast i8* %[[VAL_68]] to [3 x [3 x i32]]* // CHECK: %[[VAL_71:.*]] = getelementptr inbounds i8, i8* %[[VAL_72:.*]], i64 0 // CHECK: %[[VAL_73:.*]] = bitcast i8* %[[VAL_71]] to [2 x i32]* // CHECK: %[[VAL_74:.*]] = getelementptr inbounds i8, i8* %[[VAL_75:.*]], i64 0 // CHECK: %[[VAL_76:.*]] = bitcast i8* %[[VAL_74]] to [2 x [3 x i32]]* +// CHECK: %[[VAL_65:.*]] = getelementptr inbounds i8, i8* %[[VAL_66:.*]], i64 0 +// CHECK: %[[VAL_67:.*]] = bitcast i8* %[[VAL_65]] to [3 x [3 x i32]]* // CHECK: %[[VAL_77:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 // CHECK: %[[VAL_78:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 // CHECK: %[[VAL_79:.*]] = mul nuw nsw i32 %[[VAL_77]], 6 @@ -180,7 +174,7 @@ ENTRY main { // CHECK: atomic_op_loop_body: ; preds = %[[VAL_104]], %[[VAL_95]] // CHECK: %[[VAL_105:.*]] = load i32, i32* %[[VAL_64]], align 4 // CHECK: store i32 %[[VAL_105]], i32* %[[VAL_63]], align 4 -// CHECK: call void @mul_s32(i32* %[[VAL_63]], i32* %[[VAL_98]], i32* %[[VAL_63]]) +// CHECK: call void @{{.+}}(i32* %[[VAL_63]], i32* %[[VAL_98]], i32* %[[VAL_63]]) // CHECK: %[[VAL_106:.*]] = load i32, i32* %[[VAL_63]], align 4 // CHECK: %[[VAL_107:.*]] = cmpxchg i32* %[[VAL_97]], i32 %[[VAL_105]], i32 %[[VAL_106]] seq_cst seq_cst // CHECK: %[[VAL_108:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 0 @@ -219,14 +213,12 @@ ENTRY main { // CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) { // CHECK: entry: // CHECK: %[[VAL_146:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0 -// CHECK: %[[VAL_120:.*]] = bitcast i8* %[[VAL_118]] to [4 x i32]* -// CHECK: %[[VAL_121:.*]] = getelementptr inbounds i8, i8* %[[VAL_122:.*]], i64 0 -// CHECK: %[[VAL_123:.*]] = bitcast i8* %[[VAL_121]] to [4 x i32]* // CHECK: %[[VAL_124:.*]] = getelementptr inbounds i8, i8* %[[VAL_125:.*]], i64 0 // CHECK: %[[VAL_126:.*]] = bitcast i8* %[[VAL_124]] to i32* // CHECK: %[[VAL_127:.*]] = getelementptr inbounds i8, i8* %[[VAL_128:.*]], i64 0 // CHECK: %[[VAL_129:.*]] = bitcast i8* %[[VAL_127]] to i32* +// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0 +// CHECK: %[[VAL_120:.*]] = bitcast i8* %[[VAL_118]] to [4 x i32]* // CHECK: %[[VAL_130:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 // CHECK: %[[VAL_131:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 // CHECK: %[[VAL_132:.*]] = mul nuw nsw i32 %[[VAL_130]], 1 diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 4fb7edd0104..4ddd8ce5146 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1932,6 +1932,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandleReduceWindow(HloInstruction* reduce_window) override { + if (reduce_window->shape().IsTuple()) { + return Status(tensorflow::error::UNIMPLEMENTED, + "Variadic reduce window op is not yet fully supported."); + } auto operand = reduce_window->operand(0); const Window& window = reduce_window->window(); HloComputation* function = reduce_window->to_apply(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 41488dcdaaa..b24c35c4c69 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -515,11 +515,23 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kReduceWindow: + TF_RET_CHECK(proto.operand_ids_size() % 2 == 0) + << "Reduce window should have an even number of operands but " + "sees " + << proto.operand_ids_size(); TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "ReduceWindow should have 1 called computation but sees " << proto.called_computation_ids_size(); - instruction = CreateReduceWindow(shape, operands(0), operands(1), - proto.window(), computations(0)); + { + const auto reduce_operands = all_operands(); + auto inputs = absl::MakeSpan(reduce_operands) + .subspan(0, reduce_operands.size() / 2); + auto init_values = + absl::MakeSpan(reduce_operands) + .subspan(reduce_operands.size() / 2, reduce_operands.size()); + instruction = CreateReduceWindow(shape, inputs, init_values, + proto.window(), computations(0)); + } break; case HloOpcode::kSelectAndScatter: TF_RET_CHECK(proto.called_computation_ids_size() == 2) @@ -1273,6 +1285,13 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, shape, operand, init_value, window, reduce_computation); } +/* static */ std::unique_ptr HloInstruction::CreateReduceWindow( + const Shape& shape, absl::Span operands, + absl::Span init_values, const Window& window, + HloComputation* reduce_computation) { + return absl::make_unique( + shape, operands, init_values, window, reduce_computation); +} /* static */ std::unique_ptr HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* operand, diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 9675a2f0f0d..5901e446df5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -830,6 +830,16 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation); + // A more general, multiple-argument version of the above. + // The reduce_computation being applied,now takes N arguments: + // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ..., + // valueN], and returns an N-tuple. The operands and init_values now each + // contain a span of N input arrays and n initial values. + static std::unique_ptr CreateReduceWindow( + const Shape& shape, absl::Span operands, + absl::Span init_values, const Window& window, + HloComputation* reduce_computation); + // Creates a batch-norm-training instruction. static std::unique_ptr CreateBatchNormTraining( const Shape& shape, HloInstruction* operand, HloInstruction* scale, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 45b2d885d8e..8cb7d91f5ac 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2237,9 +2237,21 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( HloReduceWindowInstruction::HloReduceWindowInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation) + : HloReduceWindowInstruction(shape, absl::MakeSpan(&operand, 1), + absl::MakeSpan(&init_value, 1), window, + reduce_computation) {} + +HloReduceWindowInstruction::HloReduceWindowInstruction( + const Shape& shape, absl::Span operands, + absl::Span init_values, const Window& window, + HloComputation* reduce_computation) : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) { - AppendOperand(operand); - AppendOperand(init_value); + for (auto* operand : operands) { + AppendOperand(operand); + } + for (auto* init_value : init_values) { + AppendOperand(init_value); + } AppendComputation(reduce_computation); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 88e874347bd..848674fc604 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -1294,10 +1295,43 @@ class HloReduceWindowInstruction : public HloInstruction { HloInstruction* init_value, const Window& window, HloComputation* reduce_computation); + explicit HloReduceWindowInstruction( + const Shape& shape, absl::Span operands, + absl::Span init_values, const Window& window, + HloComputation* reduce_computation); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns the number of input arrays (and, consequentially, the number of + // init values) this reduce has. + int64 input_count() const { return operand_count() / 2; } + // Returns the input tensors to be reduced. + absl::Span input_arrays() const { + return absl::MakeSpan(operands()).subspan(0, input_count()); + } + // Returns the init values of the reduction. + absl::Span init_values() const { + return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); + } + // Returns the shapes of input tensors to be reduced. + absl::InlinedVector input_array_shapes() const { + absl::InlinedVector shapes; + for (const auto* op : input_arrays()) { + VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n"; + shapes.push_back(&op->shape()); + VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n"; + } + return shapes; + } + // Returns the init values of the reduction. + absl::InlinedVector init_value_shapes() const { + absl::InlinedVector shapes; + for (const auto* op : init_values()) { + shapes.push_back(&op->shape()); + } + return shapes; + } private: std::vector ExtraAttributesToStringImpl( @@ -1310,6 +1344,7 @@ class HloReduceWindowInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; + Window window_; }; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index b50c7d9a584..e14d86e6bc0 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -119,7 +119,7 @@ namespace xla { V(kRecvDone, "recv-done", 1) \ V(kReduce, "reduce", kHloOpcodeIsVariadic) \ V(kReducePrecision, "reduce-precision", 1) \ - V(kReduceWindow, "reduce-window", 2) \ + V(kReduceWindow, "reduce-window", kHloOpcodeIsVariadic) \ V(kRemainder, "remainder", 2) \ V(kReplicaId, "replica-id", 0) \ V(kReshape, "reshape", 1) \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index cceb60a70e9..95bb81c60f6 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -65,6 +65,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kRng: case HloOpcode::kSort: case HloOpcode::kTuple: + case HloOpcode::kReduceWindow: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; default: diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 977f6ee8ea6..f70f91d7c26 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -42,6 +42,10 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { HloSharding HloSharding::PartialTile( const Array& group_tile_assignment, absl::Span> replication_groups) { + CHECK_EQ(group_tile_assignment.num_elements(), replication_groups.size()); + if (replication_groups.size() == 1) { + return Replicate(); + } auto new_tile_dims = group_tile_assignment.dimensions(); new_tile_dims.push_back(replication_groups[0].size()); auto new_tile_assignment = Array(new_tile_dims); @@ -56,6 +60,9 @@ HloSharding HloSharding::PartialTile( HloSharding HloSharding::PartialTile( const Array& tile_assignment_last_dim_replicate) { + if (tile_assignment_last_dim_replicate.num_dimensions() == 1) { + return Replicate(); + } if (tile_assignment_last_dim_replicate.dimensions().back() == 1) { auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions(); new_tile_dims.pop_back(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 9940b032558..c8b4d2f013d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -264,3 +264,13 @@ cc_library( "@llvm-project//llvm:Core", ], ) + +tf_cc_test( + name = "ir_array_test", + srcs = ["ir_array_test.cc"], + deps = [ + ":ir_array", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index f8514a6cba3..164c8f7e1c8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Value.h" @@ -44,32 +43,37 @@ using llvm_ir::IrArray; Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) { indexed_generators_[hlo] = [=](const IrArray::Index& index) -> StatusOr { - auto cache = generated_value_cache_.find(hlo); - if (cache != generated_value_cache_.end()) { - auto key = std::make_pair(b_->GetInsertBlock(), index.multidim()); - if (llvm::Value* generated_value = - FindOrDefault(cache->second, key, nullptr)) { - VLOG(3) << "The cached generated value is reused."; - return generated_value; - } - auto null_key = std::make_pair(nullptr, index.multidim()); - if (llvm::Value* generated_value = - FindOrDefault(cache->second, null_key, nullptr)) { + if (llvm::Value* generated_value = FindOrDefault( + generated_value_cache_[hlo], index.multidim(), nullptr)) { + llvm::BasicBlock* generated_value_bb = nullptr; + if (auto* generated_instruction = + llvm::dyn_cast(generated_value)) { + generated_value_bb = generated_instruction->getParent(); + } + // Ideally, we should be able to reuse the cached generated value if it + // dominates the current insertion block. However, the check for dominance + // can be expensive and unreliable when the function is being constructed. + // + // It's also worth experimenting what if we don't do caching at all. + // LLVM's CSE or GVN should be able to easily merge common subexpressions + // that would be regenerated without caching. But this might increase the + // JIT compilation time. + if (generated_value_bb == nullptr || + generated_value_bb == b_->GetInsertBlock()) { VLOG(3) << "The cached generated value is reused."; return generated_value; } + VLOG(3) << "The cached generated value can't be reused, because it is in " + "a different BB (" + << generated_value_bb->getName().str() + << ") from the current insertion block (" + << b_->GetInsertBlock()->getName().str() << ")."; } TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value, elemental_emitter_->MakeElementGenerator( hlo, indexed_generators_)(index)); - llvm::BasicBlock* generated_value_bb = nullptr; - if (auto* generated_instruction = - llvm::dyn_cast(generated_value)) { - generated_value_bb = generated_instruction->getParent(); - } - generated_value_cache_[hlo][std::make_pair( - generated_value_bb, index.multidim())] = generated_value; + generated_value_cache_[hlo][index.multidim()] = generated_value; return generated_value; }; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index e19e970cb24..d13b0262180 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" @@ -154,10 +153,9 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault { // Cache of generated values, lest we regenerate an element of a node with // multiple outgoing edges - absl::flat_hash_map>, - llvm::Value*>> + absl::flat_hash_map< + const HloInstruction*, + absl::flat_hash_map, llvm::Value*>> generated_value_cache_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 73d430e2c54..6da4d08f182 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -527,5 +527,27 @@ IrArray IrArray::CastToShape(const Shape& new_shape, return new_irarray; } +bool IrArray::Index::ShapeIsCompatible(const Shape& a, const Shape& b) { + // Compute strides for two sides of the comparison. Sometimes different shapes + // give the same strides: + // [10, 20, 30, 1]{3,2,1,0} vs [10, 20, 1, 30]{3,2,1,0} + // which should be considered compatible. + const auto get_strides = [](const Shape& shape) { + int rank = shape.dimensions().size(); + int64 stride = 1; + std::vector strides; + for (int i = 0; i < rank; i++) { + auto dim = shape.dimensions(shape.layout().minor_to_major(i)); + if (dim != 1) { + stride *= dim; + strides.push_back(stride); + } + } + return strides; + }; + + return get_strides(a) == get_strides(b); +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 32273de38ea..dfc49ce3dde 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -110,15 +110,12 @@ class IrArray { bool LinearValidOnShape(const Shape& a) const; + static bool ShapeIsCompatible(const Shape& a, const Shape& b); + bool ShapeIsCompatible(const Shape& a) const { - Shape own_shape = ShapeUtil::MakeShape(a.element_type(), dims_); - *own_shape.mutable_layout() = layout_; - // The shape 'a' could have dynamic dimensions set. Before we check for - // equality, we need to copy the information which dimensions are dynamic. - for (int64 i = 0; i < a.rank(); ++i) { - own_shape.set_dynamic_dimension(i, a.is_dynamic_dimension(i)); - } - return ShapeUtil::Equal(own_shape, a); + return ShapeIsCompatible( + a, ShapeUtil::MakeShapeWithLayout(a.element_type(), dims_, + layout_.minor_to_major())); } // Given that "this" is the target index of a reshape from `input_shape` diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array_test.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array_test.cc new file mode 100644 index 00000000000..7f464b76e4f --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array_test.cc @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace llvm_ir { +namespace { + +TEST(IrArrayTest, TestShapeIsCompatible) { + xla::Shape a = ShapeUtil::MakeShapeWithLayout(F32, {1, 10, 20}, {2, 1, 0}); + xla::Shape b = ShapeUtil::MakeShapeWithLayout(F32, {1, 10, 20}, {2, 0, 1}); + xla::Shape c = ShapeUtil::MakeShapeWithLayout(F32, {10, 1, 20}, {2, 1, 0}); + + xla::Shape d = ShapeUtil::MakeShapeWithLayout(F32, {1, 10, 30}, {2, 1, 0}); + xla::Shape e = ShapeUtil::MakeShapeWithLayout(F32, {1, 10, 30}, {2, 0, 1}); + xla::Shape f = ShapeUtil::MakeShapeWithLayout(F32, {10, 1, 30}, {2, 1, 0}); + + EXPECT_TRUE(IrArray::Index::ShapeIsCompatible(a, b)); + EXPECT_TRUE(IrArray::Index::ShapeIsCompatible(a, c)); + EXPECT_FALSE(IrArray::Index::ShapeIsCompatible(a, d)); + EXPECT_FALSE(IrArray::Index::ShapeIsCompatible(a, e)); + EXPECT_FALSE(IrArray::Index::ShapeIsCompatible(a, f)); +} + +} // namespace +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 8fb01fa45f2..5b133a521e3 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -1207,6 +1207,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { bool repacked = false; for (int retry_number = 0; retry_number < options_.max_retries; retry_number++) { + AddRequiredAssignmentsForColocatedIntervals(colocated_intervals); bool final_retry = (retry_number == options_.max_retries - 1); options_.prefetch_interval_picker->SetRetryNumber(retry_number); Result result = @@ -1217,7 +1218,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { (!final_retry && result_failed_because_of_async_copy(result))) { UncommitPendingChunks(absl::MakeSpan(allocation_values)); VLOG(2) << "Couldn't allocate. Retry number " << retry_number; - } else if (result_is(result, Result::kFailOutOfMemory) && + } else if ((result_is(result, Result::kFailOutOfMemory) || + options_.repack_after_every_allocation) && num_repacks_ < options_.max_repacks && !repacked) { UncommitPendingChunks(absl::MakeSpan(allocation_values)); ++num_repacks_; @@ -1251,10 +1253,9 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { return result_; } -void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( +void AlternateMemoryBestFitHeap::AddRequiredAssignmentsForColocatedIntervals( absl::Span - colocated_intervals, - std::vector& allocation_values) { + colocated_intervals) { // TODO(berkin): For now, place the phi values due to conditionals in // default memory. for (const BufferInterval* colocated_interval : colocated_intervals) { @@ -1273,7 +1274,12 @@ void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( } } } +} +void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( + absl::Span + colocated_intervals, + std::vector& allocation_values) { // Create AllocationValues for all the colocated intervals. for (const auto& colocated_interval : colocated_intervals) { CreateAllocationValues(*colocated_interval, allocation_values); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index b50015f82b3..cb459c68be1 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -462,6 +462,9 @@ class MemorySpaceAssignment { // max_repacks is greater than 0. MemorySpaceAssignmentRepacker* repacker = nullptr; + // This is only useful for testing, repack after every allocation. + bool repack_after_every_allocation = false; + // If true, tries allocating buffers across (e.g., before and inside a while // loop body) sequential calls (kWhile, kCall, and kConditional). bool allocate_across_sequential_calls = false; @@ -1201,6 +1204,11 @@ class AlternateMemoryBestFitHeap absl::optional AliasedRequiredAssignmentForUse( const AllocationValue::Use& use) const; + // Goes through the colocated intervals and adds any required assignment. + void AddRequiredAssignmentsForColocatedIntervals( + absl::Span + colocated_intervals); + // Propagates aliased required assignment for a given position. void AddAliasedRequiredAssignment( const HloInstruction* instruction, ShapeIndex index, diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index f574bd37937..187076abe8a 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -4136,10 +4136,12 @@ class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker { public: explicit FakeMemorySpaceAssignmentRepacker( absl::flat_hash_map, int64>& repack_map, - std::function)> check_fun = nullptr) + std::function)> check_fun = nullptr, + bool always_return_modified = false) : MemorySpaceAssignmentRepacker(/*max_size=*/128, /*alignment=*/8), repack_map_(repack_map), - check_fun_(check_fun) {} + check_fun_(check_fun), + always_return_modified_(always_return_modified) {} StatusOr Repack(absl::Span allocations) override { bool modified = false; @@ -4173,13 +4175,14 @@ class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker { check_fun_(allocations); } - return modified; + return always_return_modified_ || modified; } private: // A map from (start_time, offset) to new_offset. absl::flat_hash_map, int64> repack_map_; std::function)> check_fun_; + bool always_return_modified_; }; TEST_P(MemorySpaceAssignmentTest, Repack) { @@ -4418,6 +4421,66 @@ TEST_P(MemorySpaceAssignmentTest, RepackExportsAliasedOffsets) { options); } +TEST_P(MemorySpaceAssignmentTest, + RepackShouldntEraseRequiredAssignmentForConditionalOutput) { + // This is a test case for b/171040271. Repacks erase the required assignments + // (since some required assignments are inserted conditionally based on + // allocation decisions), including the fact that conditional outputs are + // always required to get assignments in the default memory. After repacking, + // this required assignment was never added back, causing conditionals to get + // alternate-memory allocations. + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]) parameter(0) + gte = f32[3] get-tuple-element(p0), index=0 + neg1 = f32[3] negate(gte) + ROOT tuple1 = (f32[3]) tuple(neg1) + } + + false_computation { + p0 = (f32[3]) parameter(0) + gte = f32[3] get-tuple-element(p0), index=0 + neg2 = f32[3] negate(gte) + ROOT tuple2 = (f32[3]) tuple(neg2) + } + + ENTRY entry { + p0 = f32[3] parameter(0) + p1 = pred[] parameter(1) + copy = f32[3] copy(p0) + tuple = (f32[3]) tuple(copy) + conditional = (f32[3]) conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation + ROOT gte = f32[3] get-tuple-element(conditional), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + absl::flat_hash_map, int64> repack_map; + FakeMemorySpaceAssignmentRepacker repacker = + FakeMemorySpaceAssignmentRepacker(repack_map, nullptr, + /*always_return_modified=*/true); + MemorySpaceAssignment::Options options; + options.max_size_in_bytes = 128; + options.alignment_in_bytes = 8; + options.verify = true; + options.max_repacks = 10; + options.repacker = &repacker; + options.repack_after_every_allocation = true; + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10); + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + /*buffer_interval_compare=*/{}, &prefetch_interval_picker, + options); + // Make sure the root of the entry computation is in the default memory space. + EXPECT_EQ(module->entry_computation() + ->root_instruction() + ->shape() + .layout() + .memory_space(), + kDefaultMemorySpace); +} + TEST_P(MemorySpaceAssignmentTest, Determinism) { // Run memory space assignment a few times to make sure every time it compiles // to the same thing. diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index a664a316e13..9194e44ba5f 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -70,7 +70,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { // Legalize from HLO to LHLO. pm.addPass(::mlir::mhlo::createLegalizeToLhloPass()); // Moving `AllocOp`s and inserting missing `DeallocOp`s - pm.addPass(::mlir::createBufferPlacementPass()); + pm.addPass(::mlir::createBufferHoistingPass()); + pm.addPass(::mlir::createBufferDeallocationPass()); // Next, we can strip the outer fusion operation. pm.addPass(createFusionOpRemoverPass()); // Remove unnecessary LHLO copies. diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index b275dd4525f..1457fa5df1d 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -322,12 +322,9 @@ Status LhloDialectEmitter::HandleGather(HloInstruction* instr) { TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); OpBuilder func_builder(function.getBody()); - // TODO(pifon): Clean-up LHLO GatherOp to be consistent with HLO GatherOp. func_builder.create( getLocation(instr), function.getArgument(0), function.getArgument(1), - dim_numbers.index_vector_dim(), dim_numbers.offset_dims(), slice_sizes, - dim_numbers.collapsed_slice_dims(), dim_numbers.start_index_map(), - function.getArgument(2)); + dim_numbers, slice_sizes, function.getArgument(2)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc index 52af857efe5..f00f46b83c1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -562,9 +562,20 @@ StatusOr> MlirCompilerImpl::RunBackend( auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(), GetGpuVersion(stream_exec), config, GetLibdeviceDir(config))); - TF_ASSIGN_OR_RETURN( - auto cubin, se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(), - gpu::PtxOptsFromConfig(config))); + // Allow to fallback to the driver compilation when ptxas isn't able to + // compile. + StatusOr> maybe_cubin = + se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(), + gpu::PtxOptsFromConfig(config)); + std::vector cubin; + if (maybe_cubin.ok()) { + cubin = std::move(maybe_cubin).ValueOrDie(); + } else if (maybe_cubin.status().code() == + tensorflow::error::Code::UNIMPLEMENTED) { + xla::gpu::WarnIfBadDriverJITVersion(); + } else { + return maybe_cubin.status(); + } auto thunk_schedule = absl::make_unique( std::make_unique(std::move(thunk_sequence)), diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo index 470ae348740..6a4f020b850 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo @@ -12,9 +12,11 @@ ENTRY %Gather (x: f32[100,10], y: s64[4,6]) -> f32[4,6,10] { // CHECK: func @gather(%[[ARG0:.*]]: [[TYPE0:.*]], %[[ARG1:.*]]: [[TYPE1:.*]], // CHECK-SAME: %[[RESULT:.*]]: [[RTYPE:.*]]) { // CHECK-NEXT: "lmhlo.gather"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) { -// CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>, -// CHECK-SAME: index_vector_dim = 2 : i64, -// CHECK-SAME: offset_dims = dense<2> : tensor<1xi64>, -// CHECK-SAME: slice_sizes = dense<[1, 10]> : tensor<2xi64>, -// CHECK-SAME: start_index_map = dense<0> : tensor<1xi64> +// CHECK-SAME: dimension_numbers = { +// CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>, +// CHECK-SAME: index_vector_dim = 2 : i64, +// CHECK-SAME: offset_dims = dense<2> : tensor<1xi64>, +// CHECK-SAME: start_index_map = dense<0> : tensor<1xi64> +// CHECK-SAME: }, +// CHECK-SAME: slice_sizes = dense<[1, 10]> : tensor<2xi64> // CHECK-SAME: } : ([[TYPE0]], [[TYPE1]], [[RTYPE]]) -> () diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index eb29fa89098..8ebb522d6a8 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -2125,6 +2125,7 @@ XLA_BINOP_PATTERN(ShiftRightLogical) XLA_TERNOP_PATTERN(Clamp); XLA_TERNOP_PATTERN(Scatter); XLA_TERNOP_PATTERN(Select); +XLA_TERNOP_PATTERN(SelectAndScatter); #undef XLA_TERNOP_PATTERN namespace detail { diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index a96c9c34260..43e3ea15b5f 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2084,7 +2084,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, arg_shapes.size()); } int64 num_reduced_args = arg_shapes.size() / 2; - auto reduced_args = arg_shapes.subspan(0, num_reduced_args); // Check that all of the reduced tensors have the same dimensions. The element // types may be different. @@ -2097,7 +2096,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(*reduced_args[i])); } } - // Check that the dimensions to reduce are in-bounds for the given shape. // We've already verified all reduced tensors have the same dimensions, so it // doesn't matter which one we choose. @@ -2156,6 +2154,43 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferReduceWindowShape(operand_shape, init_value_shape, window); } +/* static */ StatusOr ShapeInference::InferReduceWindowShape( + absl::Span operands, absl::Span init_values, + const Window& window, const ProgramShape& to_apply_shape) { + auto number_of_input = operands.size(); + // Check that all of the reduced tensors have the same dimensions. The element + // types may be different. + for (int64 i = 1; i < number_of_input; ++i) { + if (!ShapeUtil::SameDimensions(*operands[0], *operands[i])) { + return InvalidArgument( + "All reduced tensors must have the same dimension. Tensor 0 has " + "shape %s, Tensor %d has shape %s", + ShapeUtil::HumanString(*operands[0]), i, + ShapeUtil::HumanString(*operands[i])); + } + } + std::vector operand_element_type_vec; + for (const Shape* s : operands) { + operand_element_type_vec.push_back(s->element_type()); + } + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values, + operand_element_type_vec, + /*inputs=*/number_of_input)); + std::vector output_shape_vec; + for (int i = 0; i < operands.size(); ++i) { + TF_ASSIGN_OR_RETURN( + auto cur_output_shape, + InferReduceWindowShape(*operands[i], *init_values[i], window)); + output_shape_vec.push_back(cur_output_shape); + } + if (ShapeUtil::IsScalar(to_apply_shape.result())) { + CHECK_EQ(output_shape_vec.size(), 1); + return output_shape_vec[0]; + } else { + return ShapeUtil::MakeTupleShape(output_shape_vec); + } +} + /* static */ StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index f03e4e5fa98..eb969873fd0 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -164,10 +164,16 @@ class ShapeInference { static StatusOr InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value, const Window& window, const ProgramShape& to_apply_shape); - static StatusOr InferReduceWindowShape(const Shape& operand_shape, const Shape& init_value, const Window& window); + static StatusOr InferReduceWindowShape( + absl::Span operands, absl::Span init_values, + const Window& window, const ProgramShape& to_apply_shape); + + static StatusOr InferReduceWindowShape( + absl::Span operands, absl::Span init_values, + const Window& window); // Infers the shape produced by scattering the given source shape to the // selected indices of each window on the operand shape. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 00ecb254a17..73bbe5cb3bd 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -912,6 +913,32 @@ TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) { inferred_status.ValueOrDie())); } +TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1}); + std::vector args = {&f32_arg_shape, &s32_arg_shape}; + std::vector inits = {&f32_, &s32_}; + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + std::vector window_dimensions = {1, 2, 4}; + std::vector window_strides = {1, 1, 1}; + std::vector> padding_values = + MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions, + window_strides, Padding::kValid); + TF_ASSERT_OK_AND_ASSIGN( + Window window, + ShapeInference::InferWindowFromDimensions( + window_dimensions, window_strides, padding_values, {}, {})); + auto inferred_status = ShapeInference::InferReduceWindowShape( + absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply); + VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n"; + EXPECT_IS_OK(inferred_status.status()); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}), + ShapeUtil::MakeShape(S32, {5, 2, 0})}), + inferred_status.ValueOrDie())); +} + TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); @@ -948,6 +975,29 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) { HasSubstr("must have at least 2 arguments, has 0")); } +TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1}); + std::vector args = {&f32_arg_shape, &s32_arg_shape}; + std::vector inits = {&f32_, &s32_}; + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + std::vector window_dimensions = {1, 2, 4}; + std::vector window_strides = {1, 1, 1}; + std::vector> padding_values = + MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions, + window_strides, Padding::kValid); + TF_ASSERT_OK_AND_ASSIGN( + Window window, + ShapeInference::InferWindowFromDimensions( + window_dimensions, window_strides, padding_values, {}, {})); + auto inferred_status = ShapeInference::InferReduceWindowShape( + absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply); + EXPECT_FALSE(inferred_status.status().ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("f32[] vs s32[]")); +} + TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD index 9ebaaa8242f..dfcba8f0a32 100644 --- a/tensorflow/compiler/xla/service/spmd/BUILD +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -20,6 +20,7 @@ cc_library( srcs = [ "convolution_handler.cc", "dot_handler.cc", + "fft_handler.cc", "spmd_partitioner.cc", "spmd_partitioner_util.cc", ], diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc index 30fc8355402..45bd79bfc75 100644 --- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc @@ -91,6 +91,17 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { namespace { +std::vector GetAllDevicesInOrder(const HloSharding& sharding) { + CHECK(!sharding.IsTileMaximal()); + std::vector results; + results.reserve(sharding.tile_assignment().num_elements()); + sharding.tile_assignment().Each( + [&](absl::Span /* indices */, int64 device) { + results.push_back(device); + }); + return results; +} + StatusOr PartitionBaseCase( PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, @@ -478,7 +489,8 @@ StatusOr PartitionBaseCase( auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window)); auto ar = lhs.state().collective_ops_creator.create_cross_partition_all_reduce( - b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {}, + b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), + {GetAllDevicesInOrder(lhs.sharding())}, (*lhs.state().next_channel_id)++); ar->set_sharding(HloSharding::Replicate()); return PartitionedHlo(ar, output_base_shape, lhs.state()) @@ -581,7 +593,8 @@ StatusOr PartitionBaseCase( TF_ASSIGN_OR_RETURN( auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window)); return lhs.state().collective_ops_creator.create_cross_partition_all_reduce( - b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {}, + b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), + {GetAllDevicesInOrder(lhs.sharding())}, (*lhs.state().next_channel_id)++); } return nullptr; @@ -940,7 +953,7 @@ StatusOr PartitionDotGroupOnNonContracting( other.sharding(), {other_group_dims[0]}, {other.sharding().tile_assignment().dimensions().back() / group_count}), - output_grouped); + output_grouped, /*ignore_group_order=*/true); other = other.Reshard(UngroupSharding(grouped)); partially_replicated_other = other.hlo(); top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding()); @@ -1062,7 +1075,8 @@ StatusOr PartitionDotGroupOnContracting( {output_sharding.tile_assignment().num_dimensions() - 1}, {output_sharding.tile_assignment().dimensions().back() / group_count}), - lhs_grouped); + lhs_grouped, + /*ignore_group_order=*/true); outer_output_tmp_sharding = UngroupSharding(grouped); inner_output_sharding = std::move(grouped.sharding); } else { @@ -1125,7 +1139,8 @@ StatusOr PartitionDotGroupOnContracting( inverse_grouped.device_groups, b) .collective_ops_creator.create_cross_partition_all_reduce( b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), - {}, (*lhs.state().next_channel_id)++); + {GetAllDevicesInOrder(inverse_grouped.sharding)}, + (*lhs.state().next_channel_id)++); ar->set_sharding(outer_output_tmp_sharding); return PartitionedHlo(ar, output_base_shape, lhs.state()) .Reshard(output_sharding) diff --git a/tensorflow/compiler/xla/service/spmd/fft_handler.cc b/tensorflow/compiler/xla/service/spmd/fft_handler.cc new file mode 100644 index 00000000000..4e1c6a96b81 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/fft_handler.cc @@ -0,0 +1,436 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace spmd { + +namespace { + +// Pad each partition to have size that is multiplication of num_partitions. +// For example, if input is {0, 1, 2, 3, 4, 5} and num_partitions = 2, +// after padding, it becomes {0, 1, 2, 3} in partition 0 and {4, 5, 0, 0} in +// partition 1. +absl::optional PadEachPartitionWithHaloExchange( + HloInstruction* hlo, int64 num_partitions, const HloSharding& sharding, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { + int64 size_per_partition = hlo->shape().dimensions().back(); + int64 size_padded_per_partition = + CeilOfRatio(size_per_partition, num_partitions) * num_partitions; + if (size_per_partition == size_padded_per_partition) { + return hlo; + } + // 1. Calculate left_halo size. + // left-halo size is 0 + OffsetCalculation left_halo_size_function = + OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1)); + + // 2. Calculate right_halo size. + // D = size_padded_per_partition + // S = size_per_partition + // i = shard_ordinal + // right-halo size is D * (i + 2) - S * (i + 2) = (D - S) * i + 2 * (D - S) + OffsetCalculation right_halo_size_function = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + size_padded_per_partition - size_per_partition, + 2 * (size_padded_per_partition - size_per_partition), 1)); + + auto concat = hlo; + // 3. Halo exchange. + auto halo_exchange_result = + ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, + hlo->shape().rank() - 1, sharding, collective_ops_creator, + next_channel_id, b); + + if (halo_exchange_result.has_value()) { + concat = halo_exchange_result.value(); + } else { + return absl::nullopt; + } + + // 4. Slice the valid result. + // Slice offset is (D - S) * i + OffsetCalculation start_offset_on_padded_concat_calculation = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + size_padded_per_partition - size_per_partition, 0, 1)); + auto slice_shape = concat->shape(); + slice_shape.set_dimensions(concat->shape().rank() - 1, + size_padded_per_partition); + auto zero_s32 = + b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector slice_offsets(concat->shape().rank(), zero_s32); + auto partition_ordinals = + MakeTiledPartitionOrdinals(sharding, partition_id, b); + slice_offsets[concat->shape().rank() - 1] = + start_offset_on_padded_concat_calculation.Calculate( + partition_ordinals[concat->shape().rank() - 1], b); + return b->AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, concat, slice_offsets, slice_shape.dimensions())); +} + +// If partition 0 has {0, 1, 2, 3} and num partitions is 2, after shuffling, +// the data becomes {0, 2, 1, 3}. +HloInstruction* ShuffleWithinEachPartitionUsingOneHot(HloInstruction* hlo, + int64 num_partitions, + SpmdBuilder* b) { + int64 size_per_partition = hlo->shape().dimensions().back(); + CHECK_EQ(size_per_partition % num_partitions, 0); + auto indices_iota = b->AddInstruction(HloInstruction::CreateIota( + ShapeUtil::MakeShape(S32, {size_per_partition}), 0)); + auto reshape_indices_iota = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape( + S32, {size_per_partition / num_partitions, num_partitions}), + indices_iota)); + auto transpoe_indices_iota = + b->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape( + S32, {num_partitions, size_per_partition / num_partitions}), + reshape_indices_iota, {1, 0})); + auto one_hot_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}), + b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {size_per_partition}), + transpoe_indices_iota)), + /*broadcast_dimensions=*/{1})); + + auto partition_indices = b->AddInstruction(HloInstruction::CreateIota( + ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}), 0)); + + auto shuffle_one_hot = b->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(partition_indices->shape(), + hlo->shape().element_type()), + b->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(partition_indices->shape(), PRED), + one_hot_indices, partition_indices, ComparisonDirection::kEq)))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(hlo->shape().rank() - 1); + dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + HloInstruction* dot = b->AddInstruction(HloInstruction::CreateDot( + hlo->shape(), hlo, shuffle_one_hot, dot_dnums, precision_config)); + return dot; +} + +// If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and +// num partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0} +// and partition 1 will have {1, 3, 5, 0}. +HloInstruction* ShuffleDataWithAllToAll( + HloInstruction* hlo, int64 num_partitions, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b) { + std::vector> groups(1); + std::vector partition_subgroups(num_partitions); + std::iota(partition_subgroups.begin(), partition_subgroups.end(), 0); + groups[0] = partition_subgroups; + auto all_to_all = collective_ops_creator.create_cross_partition_all_to_all( + b, {hlo}, groups, (*next_channel_id)++, hlo->shape().rank() - 1); + return all_to_all; +} + +HloInstruction* GetCorrectionFactor(HloInstruction* hlo, int64 num_partitions, + HloInstruction* partition_id, + SpmdBuilder* b) { + /* n = size_per_replica + m = num_partitions + factor = tf.exp(-2.0j * np.pi * tf.cast(position_index, tf.complex64) * + * tf.cast(tf.range(n), dtype=tf.complex64) / + (n * m)) + + */ + auto add_hlo = [&](std::unique_ptr to_add) { + return b->AddInstruction(std::move(to_add)); + }; + int64 per_replica_size = hlo->shape().dimensions().back(); + auto constant_factor = + add_hlo(HloInstruction::CreateConstant(LiteralUtil::CreateR0( + complex64(0, -2.0 * M_PI / (num_partitions * per_replica_size))))); + constant_factor = add_hlo(HloInstruction::CreateBroadcast( + hlo->shape(), constant_factor, /*broadcast_dimensions=*/{})); + auto converted_partition_id = add_hlo(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(partition_id->shape(), + hlo->shape().element_type()), + partition_id)); + // TODO(wangtao): multipy before broadcast. + auto broadcast_partition_id = add_hlo(HloInstruction::CreateBroadcast( + hlo->shape(), converted_partition_id, /*broadcast_dimensions=*/{})); + auto exp_operand = add_hlo( + HloInstruction::CreateBinary(hlo->shape(), HloOpcode::kMultiply, + constant_factor, broadcast_partition_id)); + auto iota = add_hlo( + HloInstruction::CreateIota(hlo->shape(), hlo->shape().rank() - 1)); + exp_operand = add_hlo(HloInstruction::CreateBinary( + hlo->shape(), HloOpcode::kMultiply, exp_operand, iota)); + return add_hlo( + HloInstruction::CreateUnary(hlo->shape(), HloOpcode::kExp, exp_operand)); +} + +// Sudo code for the while loop: +// def body(dest_transform, dest_core_position, source_transform, +// source_core_position, i): +// factor = tf.exp(-2.0j * np.pi * +// tf.cast(dest_core_position, tf.complex64) * +// tf.cast(source_core_position, tf.complex64) / num_partitions) +// dest_transform += factor * source_transform +// source_core_position = tf.raw_ops.CollectivePermute( +// input=source_core_position, +// source_target_pairs=source_target_pairs, +// name='source_core_position_permute') +// source_transform = tf.raw_ops.CollectivePermute( +// input=source_transform, +// source_target_pairs=source_target_pairs, +// name='source_transform_permute') +// i += 1 +// return (dest_transform, dest_core_position, source_transform, +// source_core_position, i) +HloInstruction* GetFinalFftUsingCollectivePermute( + HloInstruction* hlo, const HloSharding& sharding, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64 num_partitions, HloInstruction* partition_id, int64* next_channel_id, + HloModule* module, SpmdBuilder* b) { + auto iteration = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto converted_partition_id = b->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(partition_id->shape(), + hlo->shape().element_type()), + partition_id)); + // Buid while loop body. + SpmdBuilder body_b("fft_collective_permute_body", hlo); + auto param = body_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape( + {hlo->shape(), hlo->shape(), converted_partition_id->shape(), + converted_partition_id->shape(), iteration->shape()}), + "param")); + auto dest_transform = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(hlo->shape(), param, 0)); + auto source_transform = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(hlo->shape(), param, 1)); + auto dest_partition_id = + body_b.AddInstruction(HloInstruction::CreateGetTupleElement( + converted_partition_id->shape(), param, 2)); + auto source_partition_id = + body_b.AddInstruction(HloInstruction::CreateGetTupleElement( + converted_partition_id->shape(), param, 3)); + auto i = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4)); + /* + factor = tf.exp(-2.0j * np.pi * + tf.cast(dest_partiton_id, tf.complex64) * + tf.cast(source_partition_id, tf.complex64) / + num_partitions) dest_transform += factor * source_transform + */ + auto constant_factor = body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(complex64(0, -2.0 * M_PI / num_partitions)))); + + constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary( + constant_factor->shape(), HloOpcode::kMultiply, constant_factor, + dest_partition_id)); + constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary( + constant_factor->shape(), HloOpcode::kMultiply, constant_factor, + source_partition_id)); + auto phase_factor = body_b.AddInstruction(HloInstruction::CreateUnary( + constant_factor->shape(), HloOpcode::kExp, constant_factor)); + phase_factor = body_b.AddInstruction( + HloInstruction::CreateBroadcast(hlo->shape(), phase_factor, {})); + auto phase_adjust_source_transform = + body_b.AddInstruction(HloInstruction::CreateBinary( + hlo->shape(), HloOpcode::kMultiply, phase_factor, source_transform)); + dest_transform = body_b.AddInstruction(HloInstruction::CreateBinary( + hlo->shape(), HloOpcode::kAdd, phase_adjust_source_transform, + dest_transform)); + // collective permute for source partition_id and source_transfrom. + std::vector> src_dst_pairs; + sharding.tile_assignment().Each( + [&](absl::Span indices, int64 src_device) { + std::vector target_indices(indices.begin(), indices.end()); + target_indices.back() = (indices.back() + 1) % num_partitions; + int64 dst_device = sharding.tile_assignment()(target_indices); + src_dst_pairs.emplace_back(src_device, dst_device); + }); + + source_partition_id = + collective_ops_creator.create_cross_partition_collective_permute( + &body_b, source_partition_id, src_dst_pairs, (*next_channel_id)++); + + source_transform = + collective_ops_creator.create_cross_partition_collective_permute( + &body_b, source_transform, src_dst_pairs, (*next_channel_id)++); + + // ++i + i = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, + body_b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))))); + body_b.AddInstruction( + HloInstruction::CreateTuple({dest_transform, source_transform, + dest_partition_id, source_partition_id, i})); + + // Build while loop conditions. + auto zero = CreateZero(hlo->shape(), b); + SpmdBuilder cond_b("fft_collective_permute_condition", hlo); + auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape( + {hlo->shape(), hlo->shape(), converted_partition_id->shape(), + converted_partition_id->shape(), iteration->shape()}), + "param")); + auto cond_i = cond_b.AddInstruction( + HloInstruction::CreateGetTupleElement(iteration->shape(), cond_param, 4)); + cond_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), cond_i, + cond_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions))), + ComparisonDirection::kLt)); + + // Build while loop. + auto while_loop = b->AddInstruction(HloInstruction::CreateWhile( + cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()), + module->AddEmbeddedComputation(body_b.Build()), + b->AddInstruction( + HloInstruction::CreateTuple({zero, hlo, converted_partition_id, + converted_partition_id, iteration})))); + + return b->AddInstruction( + HloInstruction::CreateGetTupleElement(hlo->shape(), while_loop, 0)); +} + +// Slice valid data in each partition. +HloInstruction* SliceValidData(HloInstruction* hlo, const Shape& target_shape, + SpmdBuilder* b) { + std::vector start_indices(target_shape.rank(), 0); + std::vector strides(target_shape.rank(), 1); + return b->AddInstruction(HloInstruction::CreateSlice( + target_shape, hlo, start_indices, target_shape.dimensions(), strides)); +} + +} // namespace + +// Distributed FFT using the algorithm described in go/tpu-spmd-fft. +Status SpmdPartitioningVisitor::HandleFft(HloInstruction* hlo) { + if (hlo->operand(0)->shape().rank() < 3 || hlo->fft_type() != FftType::FFT) { + return DefaultAction(hlo); + } + + // Only support input_length equals fft_length's case. + int64 input_length = hlo->operand(0)->shape().dimensions().back(); + int64 fft_length = hlo->fft_length().back(); + if (input_length != fft_length || input_length % num_partitions_ != 0) { + return DefaultAction(hlo); + } + + // Support partition at the last dimension only. + if (!hlo->has_sharding() || + hlo->sharding().tile_assignment().dimensions().back() != + num_partitions_) { + return DefaultAction(hlo); + } + + auto partitioned_input = + GetPartitionedHlo(hlo->operand(0)) + .PadWithValue(CreateR0WithType(hlo->shape().element_type(), 0, &b_)); + + // 1.a. Use right halo exchange to shuffle data first and slice with + // valid data. Data shuffling ensures an in-order transform that the sequences + // of data before and after the transform are the same. The data shuffling + // requires the size of data per partition is divisible by the number of + // partitions. For example, If input is {0, 1, 2, 3, 4, 5} and + // num partitions is 2, after halo exchange partition 0 has {0, 1, 2, 3} and + // partition 1 has {4, 5, 0, 0}, where 0s in the partition 1 are padding data. + // Zeros paddings append zeros to the end of the full data. + auto result = partitioned_input.hlo(); + auto padded_hlo = PadEachPartitionWithHaloExchange( + partitioned_input.hlo(), num_partitions_, hlo->sharding(), + partitioned_input.state().collective_ops_creator, + partitioned_input.state().next_channel_id, + partitioned_input.state().partition_id, partitioned_input.state().b); + + if (padded_hlo.has_value()) { + result = padded_hlo.value(); + } + + // 1.b Shuffle data within each partition using one hot and matmul. + // If partition 0 has {0, 1, 2, 3} and num partitions is 2, after shuffling, + // the data becomes {0, 2, 1, 3}. + result = ShuffleWithinEachPartitionUsingOneHot(result, num_partitions_, + partitioned_input.state().b); + // 1.c all-to-all + // If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and + // num partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0} + // and partition 1 will have {1, 3, 5, 0}. + result = ShuffleDataWithAllToAll( + result, num_partitions_, partitioned_input.state().collective_ops_creator, + partitioned_input.state().next_channel_id, partitioned_input.state().b); + // 1.d Slice valid data in each partition. + result = SliceValidData(result, partitioned_input.hlo()->shape(), &b_); + + // 2. Do local fft transform. + auto partitioned_fft_length = hlo->fft_length(); + partitioned_fft_length.back() /= num_partitions_; + result = b_.AddInstruction(HloInstruction::CreateFft( + result->shape(), result, hlo->fft_type(), partitioned_fft_length)); + + // Multiply by correct factor for local phase ajustment. + auto correction_factor = GetCorrectionFactor( + result, num_partitions_, partitioned_input.state().partition_id, + partitioned_input.state().b); + result = b_.AddInstruction(HloInstruction::CreateBinary( + result->shape(), HloOpcode::kMultiply, result, correction_factor)); + + // 3. Second phase FFT with collective permute. fft_length = num_partitions. + result = GetFinalFftUsingCollectivePermute( + result, hlo->sharding(), partitioned_input.state().collective_ops_creator, + num_partitions_, partitioned_input.state().partition_id, + partitioned_input.state().next_channel_id, module_, + partitioned_input.state().b); + + result->set_sharding(hlo->sharding()); + auto partitioned_fft = + PartitionedHlo(result, hlo->shape(), partitioned_input.state()); + SetPartitionedHlo(hlo, partitioned_fft); + return Status::OK(); +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index ceb81330639..0c56800f945 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -1000,13 +1000,6 @@ PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice( target.tile_assignment().dimensions().back()); } - // Get per_group partitioner state. - std::vector group_dims(sharding().tile_assignment().num_dimensions() - - 1); - std::iota(group_dims.begin(), group_dims.end(), 0); - auto sharding_grouped = GroupShardingOnDims(sharding(), group_dims); - auto per_group_partitioner_state = CreatePerGroupPartitioningState( - state_, sharding_grouped.device_groups, state_.b); // 2. Get the padded_hlo, do right halo exchange if needed. auto padded_hlo = PadFromPartialReplicateShape( hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims, @@ -1017,20 +1010,24 @@ PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice( } // 3. Slice out the tile from replicate ones. auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding); - // device assignment within each group is sorted in - // HloSharding::PartialTile, thus partiton_id within each group can be - // matched with the order in tile_assignment. - Array tiling_assignment(tiling_dim_factors); - tiling_assignment.FillIota(0); + // Since we are just slicing, we can just use the differences between the new + // and old offsets in the full shape as the dynamic-slice offsets. + auto padded_base_shape = shard_shape; + for (int64 i = 0; i < padded_base_shape.rank(); ++i) { + padded_base_shape.set_dimensions( + i, padded_base_shape.dimensions(i) * + temp_target_sharding.tile_assignment().dim(i)); + } + auto offsets = MakePartitionOffsets(padded_base_shape, temp_target_sharding, + state_.partition_id, state_.b); + auto old_offsets = MakePartitionOffsets(padded_base_shape, sharding(), + state_.partition_id, state_.b); + for (int64 i = 0; i < offsets.size(); ++i) { + offsets[i] = state_.b->AddInstruction(HloInstruction::CreateBinary( + offsets[i]->shape(), HloOpcode::kSubtract, offsets[i], old_offsets[i])); + } auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( - shard_shape, padded_hlo.value(), - MakePartitionOffsets(padded_hlo.value()->shape(), - target.ReplicateOnLastTileDim() - ? HloSharding::PartialTile(tiling_assignment) - : HloSharding::Tile(tiling_assignment), - per_group_partitioner_state.partition_id, - per_group_partitioner_state.b), - shard_shape.dimensions())); + shard_shape, padded_hlo.value(), offsets, shard_shape.dimensions())); slice->set_sharding(temp_target_sharding); auto result = PartitionedHlo(slice, base_shape_, state_); // If temp_target_sharding's device assignment is different from target, diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h index 86c1a97b0d2..a140e7a2b9c 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -382,6 +382,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { Status HandleDot(HloInstruction* hlo) override; Status HandleDynamicSlice(HloInstruction* hlo) override; Status HandleDynamicUpdateSlice(HloInstruction* hlo) override; + Status HandleFft(HloInstruction* hlo) override; Status HandleGather(HloInstruction* hlo) override; Status HandleGetTupleElement(HloInstruction* hlo) override; Status HandleInfeed(HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 84c40e888a3..90092645f16 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -3378,6 +3378,30 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[24,19648]"))); } +TEST_F(SpmdPartitioningTest, DotPartialDeviceOrder) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,256,4096] parameter(0), sharding={devices=[1,1,2,2]1,3,0,2 last_tile_dim_replicate} + %rhs = f32[4096,2048] parameter(1), sharding={devices=[2,2]3,1,2,0} + ROOT %dot = f32[16,256,2048] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={2}, rhs_contracting_dims={0}, + sharding={devices=[1,1,2,2]2,3,0,1 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Parameter(0), op::Shape("f32[16,256,2048]")); + auto rhs = AllOf(op::Parameter(1), op::Shape("f32[2048,1024]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)), + op::Shape("f32[16,256,1024]"))); +} + TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) { const char* const hlo_string = R"( HloModule module @@ -5115,7 +5139,7 @@ ENTRY entry { op::DynamicSlice( op::Pad(op::Concatenate(multiply, right_halo), op::Constant()), op::Reshape(), op::Constant()), - op::Reshape(), op::Constant())); + op::Subtract(), op::Subtract())); auto add_rhs = AllOf(op::Shape("f32[2,3]"), op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), op::Reshape(), op::Constant())); @@ -5167,9 +5191,10 @@ ENTRY entry { AllOf(op::Shape("f32[4,8]"), op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant()))); - auto tiled = AllOf(op::Shape("f32[4,4]"), - op::Copy(op::DynamicSlice(partially_replicated, - op::Constant(), op::Reshape()))); + auto tiled = + AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(), + op::Subtract()))); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, tiled); } @@ -5223,9 +5248,10 @@ ENTRY entry { AllOf(op::Shape("f32[4,8]"), op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant()))); - auto tiled = AllOf(op::Shape("f32[4,4]"), - op::Copy(op::DynamicSlice(partially_replicated, - op::Constant(), op::Reshape()))); + auto tiled = + AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(), + op::Subtract()))); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, tiled); } @@ -5250,9 +5276,10 @@ ENTRY entry { AllOf(op::Shape("f32[8,8]"), op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Constant()))); - auto tiled = AllOf(op::Shape("f32[4,4]"), - op::Copy(op::DynamicSlice(partially_replicated, - op::Reshape(), op::Reshape()))); + auto tiled = + AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(), + op::Subtract()))); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, tiled); } @@ -5309,7 +5336,7 @@ ENTRY entry { auto tiled = AllOf(op::Shape("f32[4,4]"), op::Copy(op::CollectivePermute(op::DynamicSlice( - partially_replicated, op::Reshape(), op::Constant())))); + partially_replicated, op::Subtract(), op::Subtract())))); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, tiled); } @@ -5816,8 +5843,8 @@ ENTRY entry { op::Shape("f32[8,801,1,1024]")); auto resharded_lhs = AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape( - op::Pad(op::DynamicSlice(lhs, op::Constant(), op::Constant(), - op::Constant(), op::Reshape()), + op::Pad(op::DynamicSlice(lhs, op::Subtract(), op::Subtract(), + op::Subtract(), op::Subtract()), op::Constant()))))), op::Shape("f32[16,401,1,512]")); auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), @@ -6110,6 +6137,43 @@ ENTRY entry { op::Shape("f32[8,105,210,32]"))); } +TEST_F(SpmdPartitioningTest, Fft3D) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = c64[1,1,6] + constant({{{(0,0),(1,1),(2,2),(3,3),(4,4),(5,5)}}}), + sharding={devices=[1,1,2]0,1} + ROOT fft = c64[1,1,6] fft(c64[1,1,6] constant), fft_type=FFT, fft_length={6}, + sharding={devices=[1,1,2]0,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto input = AllOf(op::DynamicSlice(op::Constant(), op::Constant(), + op::Constant(), op::Reshape()), + op::Shape("c64[1,1,3]")); + auto padded_input = + AllOf(op::DynamicSlice( + op::Concatenate(input, op::CollectivePermute(op::Slice())), + op::Constant(), op::Constant(), op::Reshape()), + op::Shape("c64[1,1,4]")); + + auto shuffled_input = + AllOf(op::Slice(op::AllToAll(op::Dot(padded_input, op::Convert()))), + op::Shape("c64[1,1,3]")); + + auto local_fft = AllOf(op::Fft(shuffled_input), op::Shape("c64[1,1,3]")); + + EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(op::Tuple( + _, op::Multiply(local_fft, op::Exp()), _, _, _))), + op::Shape("c64[1,1,3]"))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 6a19a1fac09..0c9a2f3ab54 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -39,6 +39,15 @@ class Shape { // Construct a shape from a ShapeProto. explicit Shape(const ShapeProto& shape_proto); + Shape(PrimitiveType element_type, absl::Span dimensions, + absl::Span dynamic_dimensions, + std::vector tuple_shapes) + : element_type_(element_type), + dimensions_(dimensions.begin(), dimensions.end()), + dynamic_dimensions_(dynamic_dimensions.begin(), + dynamic_dimensions.end()), + tuple_shapes_(std::move(tuple_shapes)) {} + // Returns a ShapeProto representation of the Shape. ShapeProto ToProto() const; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 238879ebdc0..0c877bf6102 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1074,8 +1074,7 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, } } - return std::make_tuple(!deleted_indices.empty() || !inserted_indices.empty(), - deleted_indices, inserted_indices); + return std::make_tuple(true, deleted_indices, inserted_indices); } /* static */ std::vector> diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 414b53d4f67..4e2030667ee 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -558,8 +558,6 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1))); EXPECT_FALSE(std::get<0>( ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); - EXPECT_FALSE(std::get<0>( - ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape0))); } TEST(ShapeUtilTest, ForEachIndex) { diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 201c0da87f1..1a95f2fb549 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -122,7 +122,7 @@ XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( bfloat16 interval = static_cast(0.25); std::vector counts(static_cast((high - low) / interval), 0); - constexpr int64 count = 100; + constexpr int64 count = 1000; for (int64 seed = 0; seed < count; ++seed) { auto result = UniformTest(low, high, {}, /*seed=*/seed); result.EachCell([&](absl::Span, bfloat16 value) { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 7da8d2cb84d..d0053e1b6fb 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -250,6 +250,7 @@ enum ProfileType { INVALID = 0; WINDOW = 1; FLAG = 2; + INTEGER = 3; } // Symbolization metadata for HLO Instructions. diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index 724cfe38d54..4f3f4b36970 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -73,7 +73,9 @@ tf_cuda_cc_test( "--xla_test_device=XLA_GPU", "--xla_platform=GPU", ], - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + [ + "no_cuda_asan", # TODO(b/171319142): re-enable. + ], deps = [ ":raw_api_test_lib", "//tensorflow/compiler/jit:xla_gpu_device", diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index f5bbff7827d..3818e735091 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -46,7 +46,6 @@ # # Public mobile targets, e.g. for Android: # -# filegroup ":android_proto_srcs" - Protos # cc_library ":portable_tensorflow_lib" - Native library # cc_library ":portable_tensorflow_lib_lite" - Native library, without ops, # supporting SELECTIVE_REGISTRATION feature. @@ -88,7 +87,7 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "if_nccl") # buildifier: disable=same-origin-load -load("@tf_toolchains//macros:cpp.bzl", "tensorflow_opensource_extra_deps") +load("//tensorflow:tensorflow.bzl", "tensorflow_opensource_extra_deps") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") @@ -290,12 +289,6 @@ cc_library( ], ) -alias( - name = "framework_bounds_check", - actual = "//tensorflow/core/framework:bounds_check", - visibility = ["//tensorflow/core/kernels:friends"], -) - # Minimal lib so that tools used for mobile compilation # don't have to depend on lib/platformlib. cc_library( @@ -455,6 +448,7 @@ tf_cuda_library( "//tensorflow/core/framework:reader_op_kernel.h", "//tensorflow/core/framework:register_types.h", "//tensorflow/core/framework:register_types_traits.h", + "//tensorflow/core/framework:registration_options.h", "//tensorflow/core/framework:resource_mgr.h", "//tensorflow/core/framework:resource_op_kernel.h", "//tensorflow/core/framework:rng_alg.h", @@ -840,13 +834,6 @@ tf_cuda_library( # ----------------------------------------------------------------------------- # Public Android targets -# List of protos we want on android -filegroup( - name = "android_proto_srcs", - srcs = CORE_PROTO_SRCS, - visibility = ["//visibility:public"], -) - # Sources required to build the TensorFlow framework without the runtime on # mobile platforms. This is essentially the sources required to build # tensorflow/core/framework:tensor without using granular targets. @@ -1656,7 +1643,7 @@ cc_header_only_library( ":lib", ":lib_internal", ":version_lib", - "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/platform/default/build_config:platformlib", ], ) @@ -1708,6 +1695,7 @@ tf_cuda_library( "//tensorflow/core/framework:attr_value_proto_text", "//tensorflow/core/framework:attr_value_util", "//tensorflow/core/framework:bfloat16", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/framework:common_shape_fns", "//tensorflow/core/framework:kernel_shape_util", "//tensorflow/core/framework:node_def_util", @@ -1721,7 +1709,6 @@ tf_cuda_library( "//tensorflow/core/framework:shape_inference", "//tensorflow/core/framework:tensor", "//tensorflow/core/framework:tensor_shape", - "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/platform:env_impl", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/profiler/lib:annotated_traceme", @@ -1819,7 +1806,7 @@ tf_cuda_library( ":function_ops_op_lib", ":functional_grad", ":functional_ops_op_lib", - "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/kernels:required", ]), alwayslink = 1, @@ -1872,12 +1859,6 @@ cc_library( ], ) -# TODO(gonnet): Remove this alias once all users have been moved to the actual target. -alias( - name = "tensor_testutil", - actual = "//tensorflow/core/framework:tensor_testutil", -) - # Main program for tests alias( name = "test_main", @@ -1885,14 +1866,6 @@ alias( visibility = ["//tensorflow:internal"], ) -test_suite( - name = "low_level_tests", - tests = [ - ":low_level_library_tests", - "//tensorflow/core/platform:low_level_library_tests", - ], -) - tf_cc_tests( name = "low_level_library_tests", size = "small", @@ -1913,7 +1886,6 @@ tf_cc_tests( "//tensorflow/core/lib/random:legacy_lib_random_tests", "//tensorflow/core/lib/strings:legacy_low_level_library_tests", ], - create_named_test_suite = True, deps = [ ":lib", ":lib_internal", @@ -1949,26 +1921,10 @@ tf_cc_test( ], ) -test_suite( - name = "platform_tests", - tests = [ - "//tensorflow/core/platform:abi_test", - "//tensorflow/core/platform:env_test", - "//tensorflow/core/platform:fake_python_env_test", - "//tensorflow/core/platform:file_system_test", - "//tensorflow/core/platform:numa_test", - "//tensorflow/core/platform:platform_strings_test", - "//tensorflow/core/platform:rocm_rocdl_path_test", - "//tensorflow/core/platform:setround_test", - "//tensorflow/core/platform:unbounded_work_queue_test", - "//tensorflow/core/platform:vmodule_test", - ], -) - tf_cc_test( name = "lib_jpeg_jpeg_mem_unittest", srcs = ["lib/jpeg/jpeg_mem_unittest.cc"], - data = glob(["lib/jpeg/testdata/*.jpg"]), + data = ["//tensorflow/core/lib/jpeg/testdata"], deps = [ ":jpeg_internal", ":lib", @@ -2017,15 +1973,6 @@ tf_cc_test( ], ) -test_suite( - name = "higher_level_tests", - tests = [ - ":core_higher_level_tests", - "//tensorflow/core/framework:higher_level_tests", - "//tensorflow/core/util:higher_level_tests", - ], -) - tf_cc_tests( name = "core_higher_level_tests", size = "small", @@ -2044,7 +1991,6 @@ tf_cc_tests( "//tensorflow/core/graph:validate_test.cc", "//tensorflow/core/util/sparse:higher_level_tests_group", ], - create_named_test_suite = True, linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], @@ -2166,23 +2112,12 @@ filegroup( "//tensorflow/core/lib/ssim:testdata", "//tensorflow/core/lib/psnr:testdata", # JPEG data - "lib/jpeg/testdata/jpeg_merge_test1.jpg", - "lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg", - # JPEG data for jpeg benchmark. - "lib/jpeg/testdata/small.jpg", - "lib/jpeg/testdata/medium.jpg", - # Corrupted JPEG files for tests - "lib/jpeg/testdata/bad_huffman.jpg", - "lib/jpeg/testdata/corrupt.jpg", - # -- hand-edited variant: stops at line 0 - "lib/jpeg/testdata/corrupt34_2.jpg", - # -- hand-edited variant: stops at line 4 - "lib/jpeg/testdata/corrupt34_3.jpg", - # -- hand-edited variant: stops after a restart marker - "lib/jpeg/testdata/corrupt34_4.jpg", + "//tensorflow/core/lib/jpeg/testdata", # GIF data "lib/gif/testdata/lena.gif", "lib/gif/testdata/scan.gif", + "lib/gif/testdata/red_black.gif", + "lib/gif/testdata/squares.gif", # GIF data with optimization "lib/gif/testdata/optimized.gif", # BMP data diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt index c534425eb24..7174c8d3daf 100644 --- a/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt @@ -47,5 +47,11 @@ constructing your graph if you are intermixing GIF files with BMP, JPEG, and/or PNG files. Alternately, set the expand_animations argument of this function to False, in which case the op will return 3-dimensional tensors and will truncate animated GIF files to the first frame. + +*NOTE*: If the first frame of an animated GIF does not occupy the entire +canvas (maximum frame width x maximum frame height), then it fills the +unoccupied areas (in the first frame) with zeros (black). For frames after the +first frame that does not occupy the entire canvas, it uses the previous +frame to fill the unoccupied areas. END } diff --git a/tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt b/tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt new file mode 100644 index 00000000000..066d6b5eae4 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt @@ -0,0 +1,38 @@ +op { + graph_op_name: "RaggedTensorToVariantGradient" + visibility: HIDDEN + in_arg { + name: "encoded_ragged_grad" + description: <IsCancelled() || cancel_mgr->IsCancelling()); +} +} // namespace + /*static*/ int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts, int64 num_chunks) { @@ -215,43 +224,74 @@ CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks, BaseCollectiveExecutor::~BaseCollectiveExecutor() {} void BaseCollectiveExecutor::StartAbort(const Status& s) { - VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s; - cem_->GetParamResolver()->StartAbort(s); - remote_access_->StartAbort(s); - if (cem_->GetNcclCommunicator() != nullptr) { - cem_->GetNcclCommunicator()->StartAbort(s); + Status status; + { + mutex_lock l(status_mu_); + if (!status_.ok()) { + VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: " + << s; + return; + } + status_ = StatusGroup::MakeDerived(Status( + s.code(), + absl::StrCat( + "Collective ops is aborted by: ", s.error_message(), + "\nThe error could be from a previous operation. Restart your " + "program to reset."))); + status = status_; } + LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s; + cem_->GetParamResolver()->StartAbort(status); + remote_access_->StartAbort(status); + if (cem_->GetNcclCommunicator() != nullptr) { + cem_->GetNcclCommunicator()->StartAbort(status); + } +} + +Status BaseCollectiveExecutor::GetStatus(const Status& s) { + if (s.ok()) return s; + mutex_lock l(status_mu_); + // If the collective executor is already aborted, use the aborted status + // which is more likely the actual error instead of an artifact of an + // abortion. + if (!status_.ok()) { + VLOG(2) << "Overriding status with collective ops executor status. " + "Original status: " + << s; + return status_; + } + return s; } void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params, const string& exec_key, StatusCallback done) { + // See CompleteParamsAsync() how done() and the timeout callback interacts. const auto is_callback_called = std::make_shared>(false); - - // On any individual collective Op failure we need to abort the - // BufRendezvous so that other Ops in the instance don't hang - // waiting for transmissions that will never happen. - StatusCallback done_safe = [this, done, is_callback_called](const Status& s) { - auto should_call_callback = !is_callback_called->exchange(true); - if (should_call_callback) { - if (!s.ok()) { - remote_access_->buf_rendezvous()->StartAbort(s); + auto done_safe = [this, done, ctx, is_callback_called](const Status& s) { + bool called = is_callback_called->exchange(true); + if (!called) { + if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) { + // This is a collective error. Abort CollectiveExecutor so that this + // error can propagate to other workers. + StartAbort(s); } - done(s); + done(GetStatus(s)); } }; - auto timeout_microseconds = static_cast( col_params.instance.impl_details.timeout_seconds * 1'000'000); if (timeout_microseconds > 0) { // TODO(xldrx): Share the timeout watchdog thread among collectives. SchedNonBlockingClosureAfter( - timeout_microseconds, [is_callback_called, done_safe] { - if (!is_callback_called->load()) { - auto status = Status(error::DEADLINE_EXCEEDED, - "Collective has timed out during execution."); - done_safe(status); + timeout_microseconds, [this, is_callback_called, done] { + bool called = is_callback_called->exchange(true); + if (!called) { + Status status(error::DEADLINE_EXCEEDED, + "Collective has timed out during execution."); + StartAbort(status); + done(status); } }); } @@ -307,31 +347,43 @@ void BaseCollectiveExecutor::CompleteParamsAsync( const DeviceAttributes& device, CollectiveParams* cp, CancellationManager* cancel_mgr, StatusCallback done) { cp->group.gpu_ring_order = *gpu_ring_order_; + // We need to make sure that when the timeout callback executes, + // CollectiveExecutor and CollectiveExecutorMgr are both alive. After done() + // is called, CollectiveExecutorMgr may be destructed and we don't have a way + // to keep it without making the ownerships more complicated. Therefore if the + // timeout callback executes, done_safe will become a no-op and the timeout + // callback is responsible for invoking done() at the end. const auto is_callback_called = std::make_shared>(false); - auto done_with_timeout = done; + auto done_safe = [this, is_callback_called, cancel_mgr, + done](const Status& s) { + bool called = is_callback_called->exchange(true); + if (!called) { + if (!s.ok() && !IsCancelled(cancel_mgr)) { + // This is a collective error. Abort CollectiveExecutor so that this + // error can propagate to other workers. + StartAbort(s); + } + done(GetStatus(s)); + } + }; auto timeout_microseconds = static_cast(cp->instance.impl_details.timeout_seconds * 1'000'000); if (timeout_microseconds > 0) { // TODO(xldrx): Share the timeout watchdog thread among collectives. SchedNonBlockingClosureAfter( - timeout_microseconds, [is_callback_called, done] { - auto should_call_callback = !is_callback_called->exchange(true); - if (should_call_callback) { - auto status = - Status(error::DEADLINE_EXCEEDED, - "Collective has timed out waiting for other workers."); + timeout_microseconds, [this, is_callback_called, done]() { + bool called = is_callback_called->exchange(true); + if (!called) { + Status status( + error::DEADLINE_EXCEEDED, + "Collective has timed out waiting for other workers."); + StartAbort(status); done(status); } }); - done_with_timeout = [is_callback_called, done](const Status& s) { - auto should_call_callback = !is_callback_called->exchange(true); - if (should_call_callback) { - done(s); - } - }; } cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, - done_with_timeout); + done_safe); } Status BaseCollectiveExecutor::CreateCollective( diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h index 4081b887add..142c825df55 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.h +++ b/tensorflow/core/common_runtime/base_collective_executor.h @@ -108,7 +108,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor { ~BaseCollectiveExecutor() override; - void StartAbort(const Status& s) override; + void StartAbort(const Status& s) override TF_LOCKS_EXCLUDED(status_mu_); void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params, const string& exec_key, StatusCallback done) override; @@ -148,6 +148,8 @@ class BaseCollectiveExecutor : public CollectiveExecutor { // collective instance key -> number of local devices for which NCCL ops have // been launched. std::unordered_map launched_ TF_GUARDED_BY(launch_mu_); + mutex status_mu_; + Status status_ TF_GUARDED_BY(status_mu_); private: Status CreateCollective(const CollectiveParams& col_params, @@ -155,6 +157,9 @@ class BaseCollectiveExecutor : public CollectiveExecutor { // Check if all ops on which this collective depends on have launched. bool CheckDependencies(const CollectiveParams& col_params) TF_EXCLUSIVE_LOCKS_REQUIRED(launch_mu_); + // Tries to return the status that is the original error. It returns the + // aborted status if the collective executor is aborted. + Status GetStatus(const Status& s) TF_LOCKS_EXCLUDED(status_mu_); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/buf_rendezvous.cc b/tensorflow/core/common_runtime/buf_rendezvous.cc index 49cc9fd3db8..fc05ad0dd96 100644 --- a/tensorflow/core/common_runtime/buf_rendezvous.cc +++ b/tensorflow/core/common_runtime/buf_rendezvous.cc @@ -20,10 +20,20 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" namespace tensorflow { +namespace { +void DeregisterCancellation(BufRendezvous::Hook* h) { + if (h->cancellation_manager != nullptr) { + h->cancellation_manager->DeregisterCallback(h->cancellation_token); + h->cancellation_manager = nullptr; + h->cancellation_token = CancellationManager::kInvalidToken; + } +} +} // namespace BufRendezvous::~BufRendezvous() { mutex_lock l(mu_); @@ -51,6 +61,9 @@ void BufRendezvous::StartAbort(const Status& s) { void BufRendezvous::PurgeTable(const Status& s, HookTable* table) { for (auto& it : *table) { Hook* h = it.second; + if (h->cancellation_manager != nullptr) { + h->cancellation_manager->TryDeregisterCallback(h->cancellation_token); + } if (h->cons_cb != nullptr) { h->cons_cb(s, nullptr); } @@ -73,7 +86,8 @@ string BufRendezvous::Hook::DebugString() const { void BufRendezvous::ProvideBuf(const string& key, Device* dev, DeviceContext* dev_ctx, const Tensor* v, const AllocatorAttributes& attr, - const ProducerCallback& done) { + const ProducerCallback& done, + CancellationManager* cancellation_manager) { Hook* h = nullptr; Status providebuf_status; do { @@ -82,9 +96,13 @@ void BufRendezvous::ProvideBuf(const string& key, Device* dev, providebuf_status = status_; break; } else { + CancellationToken cancellation_token = CancellationManager::kInvalidToken; auto it = hook_table_.find(key); if (it == hook_table_.end()) { - h = new Hook; + if (cancellation_manager != nullptr) { + cancellation_token = cancellation_manager->get_cancellation_token(); + } + h = new Hook(cancellation_manager, cancellation_token); it = hook_table_.insert(std::make_pair(key, h)).first; } else { if (it->second->prod_cb != nullptr) { @@ -100,15 +118,27 @@ void BufRendezvous::ProvideBuf(const string& key, Device* dev, h->prod_value = v; h->prod_attr = attr; h->prod_cb = done; - // If consumer is waiting, kick off right away, removing Hook from table. if (h->cons_cb != nullptr) { + // If consumer is waiting, kick off right away, removing Hook from + // table. hook_table_.erase(it); } else { + if (cancellation_manager != nullptr && + !cancellation_manager->RegisterCallback( + cancellation_token, [this, key]() { CancelHook(key); })) { + // Register cancellation callback with CancellationManager. If it is + // already cancelled, call done immediately with cancelled status. + providebuf_status = errors::Cancelled( + "Operation was cancelled for BufRendezvous key ", key); + hook_table_.erase(it); + delete h; + } h = nullptr; } } } while (false); if (h) { + DeregisterCancellation(h); h->cons_cb(Status::OK(), h); } if (!providebuf_status.ok()) { @@ -118,7 +148,8 @@ void BufRendezvous::ProvideBuf(const string& key, Device* dev, void BufRendezvous::ConsumeBuf(const string& key, const string& device_name, const uint64 device_incarnation, - const ConsumerCallback& done) { + const ConsumerCallback& done, + CancellationManager* cancellation_manager) { // Check the incarnation in the request matches the current device // incarnation of the producer. Device* device; @@ -157,13 +188,26 @@ void BufRendezvous::ConsumeBuf(const string& key, const string& device_name, existing_hook->cons_cb = done; } else { // Hang consumer callback on the Hook. - Hook* h = new Hook; - hook_table_[key] = h; - h->cons_cb = done; - return; + CancellationToken cancellation_token = CancellationManager::kInvalidToken; + bool already_cancelled = false; + if (cancellation_manager != nullptr) { + cancellation_token = cancellation_manager->get_cancellation_token(); + already_cancelled = !cancellation_manager->RegisterCallback( + cancellation_token, [this, key]() { CancelHook(key); }); + } + if (already_cancelled) { + consumebuf_status = errors::Cancelled( + "Operation was cancelled for BufRendezvous key ", key); + } else { + Hook* h = new Hook(cancellation_manager, cancellation_token); + h->cons_cb = done; + it = hook_table_.insert(std::make_pair(key, h)).first; + return; + } } } while (false); if (existing_hook) { + DeregisterCancellation(existing_hook); existing_hook->cons_cb(Status::OK(), existing_hook); return; } @@ -173,6 +217,28 @@ void BufRendezvous::ConsumeBuf(const string& key, const string& device_name, } } +void BufRendezvous::CancelHook(const string& key) { + Hook* h = nullptr; + { + mutex_lock l(mu_); + auto it = hook_table_.find(key); + if (it == hook_table_.end()) return; + h = it->second; + hook_table_.erase(it); + } + if (h != nullptr) { + auto s = errors::Cancelled("Operation was cancelled for BufRendezvous key ", + key); + if (h->prod_cb != nullptr) { + h->prod_cb(s); + } + if (h->cons_cb != nullptr) { + h->cons_cb(s, /*Hook=*/nullptr); + } + delete h; + } +} + /*static*/ void BufRendezvous::DoneWithHook(Hook* h) { h->prod_cb(Status::OK()); diff --git a/tensorflow/core/common_runtime/buf_rendezvous.h b/tensorflow/core/common_runtime/buf_rendezvous.h index 74857e46a53..c8cd527f4ae 100644 --- a/tensorflow/core/common_runtime/buf_rendezvous.h +++ b/tensorflow/core/common_runtime/buf_rendezvous.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mutex.h" @@ -66,20 +67,30 @@ class BufRendezvous { AllocatorAttributes prod_attr; ProducerCallback prod_cb; ConsumerCallback cons_cb; - Hook() + CancellationManager* cancellation_manager; + CancellationToken cancellation_token; + explicit Hook(CancellationManager* cancellation_manager, + CancellationToken cancellation_token) : prod_dev(nullptr), prod_ctx(nullptr), prod_value(nullptr), prod_cb(nullptr), - cons_cb(nullptr) {} + cons_cb(nullptr), + cancellation_manager(cancellation_manager), + cancellation_token(cancellation_token) {} string DebugString() const; }; // Called to advertise availability of a Tensor value corresponding // to key. That value must stay valid until done is called. + // + // If a non-null cancellation manager is provided, this function registers a + // callback to delete the hook and invoke provider/consumer callbacks with + // cancelled error. void ProvideBuf(const string& key, Device* dev, DeviceContext* dev_ctx, const Tensor* v, const AllocatorAttributes& attr, - const ProducerCallback& done); + const ProducerCallback& done, + CancellationManager* cancellation_manager); // Called to request access to a Tensor value corresponding to key. // Consumer is provided with a Hook as soon as available. @@ -88,8 +99,17 @@ class BufRendezvous { // `device` that produced this value matches the `incarnation` expected by the // consumer, and invokes `done` with `FailedPrecondition` status and // `nullptr` hook if it does not match. + // + // If a non-null cancellation manager is provided, this function registers a + // callback to delete the hook and invoke provider/consumer callbacks with + // cancelled error. void ConsumeBuf(const string& key, const string& device, - const uint64 incarnation, const ConsumerCallback& done); + const uint64 incarnation, const ConsumerCallback& done, + CancellationManager* cancellation_manager); + + // Cancel the rendezvous entry corresponding to `key`. Triggered by the + // cancellation manager. No-op if the rendezvous was already successful. + void CancelHook(const string& key); // Consumer must call this function when it's done reading the Hook provided // by the ConsumerCallback. This function will invoke the producer callback diff --git a/tensorflow/core/common_runtime/buf_rendezvous_test.cc b/tensorflow/core/common_runtime/buf_rendezvous_test.cc index 270165114f7..b1c7b6a80e2 100644 --- a/tensorflow/core/common_runtime/buf_rendezvous_test.cc +++ b/tensorflow/core/common_runtime/buf_rendezvous_test.cc @@ -68,6 +68,7 @@ class BufRendezvousTest : public ::testing::Test { DeviceContext* fake_device_context_; std::unique_ptr dev_mgr_; std::unique_ptr br_; + CancellationManager cm_; static const string* const kDefaultKey; static const string* const kDefaultDeviceName; static const uint64 kDefaultIncarnation; @@ -90,19 +91,22 @@ TEST_F(BufRendezvousTest, CorrectUseProducerFirst) { prod_status = s; prod_callback_called = true; note.Notify(); - }); + }, + &cm_); EXPECT_FALSE(prod_callback_called); - br_->ConsumeBuf(*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [this, &cons_status, &cons_callback_called]( - const Status& s, BufRendezvous::Hook* h) { - cons_status = s; - cons_callback_called = true; - ASSERT_TRUE(h != nullptr); - EXPECT_EQ(h->prod_dev, default_device_); - EXPECT_EQ(h->prod_ctx, fake_device_context_); - EXPECT_EQ(h->prod_value, &a_); - br_->DoneWithHook(h); - }); + br_->ConsumeBuf( + *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, + [this, &cons_status, &cons_callback_called](const Status& s, + BufRendezvous::Hook* h) { + cons_status = s; + cons_callback_called = true; + ASSERT_TRUE(h != nullptr); + EXPECT_EQ(h->prod_dev, default_device_); + EXPECT_EQ(h->prod_ctx, fake_device_context_); + EXPECT_EQ(h->prod_value, &a_); + br_->DoneWithHook(h); + }, + &cm_); EXPECT_TRUE(cons_callback_called); note.WaitForNotification(); EXPECT_TRUE(prod_callback_called); @@ -116,17 +120,19 @@ TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) { bool prod_callback_called = false; bool cons_callback_called = false; Notification note; - br_->ConsumeBuf(*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [this, &cons_status, &cons_callback_called]( - const Status& s, BufRendezvous::Hook* h) { - cons_status = s; - cons_callback_called = true; - ASSERT_TRUE(h != nullptr); - EXPECT_EQ(h->prod_dev, default_device_); - EXPECT_EQ(h->prod_ctx, fake_device_context_); - EXPECT_EQ(h->prod_value, &a_); - br_->DoneWithHook(h); - }); + br_->ConsumeBuf( + *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, + [this, &cons_status, &cons_callback_called](const Status& s, + BufRendezvous::Hook* h) { + cons_status = s; + cons_callback_called = true; + ASSERT_TRUE(h != nullptr); + EXPECT_EQ(h->prod_dev, default_device_); + EXPECT_EQ(h->prod_ctx, fake_device_context_); + EXPECT_EQ(h->prod_value, &a_); + br_->DoneWithHook(h); + }, + &cm_); EXPECT_FALSE(cons_callback_called); br_->ProvideBuf( *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, @@ -134,7 +140,8 @@ TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) { prod_status = s; prod_callback_called = true; note.Notify(); - }); + }, + &cm_); EXPECT_TRUE(cons_callback_called); note.WaitForNotification(); EXPECT_TRUE(prod_callback_called); @@ -144,17 +151,19 @@ TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) { TEST_F(BufRendezvousTest, ErrorDuplicatePut) { bool prod_callback_called = false; - br_->ProvideBuf(*kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [&prod_callback_called](const Status& s) { - prod_callback_called = true; - }); + br_->ProvideBuf( + *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, + [&prod_callback_called](const Status& s) { prod_callback_called = true; }, + &cm_); Status bad_status; Notification note; - br_->ProvideBuf(*kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [&bad_status, ¬e](const Status& s) { - bad_status = s; - note.Notify(); - }); + br_->ProvideBuf( + *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, + [&bad_status, ¬e](const Status& s) { + bad_status = s; + note.Notify(); + }, + &cm_); note.WaitForNotification(); EXPECT_FALSE(bad_status.ok()); EXPECT_EQ(absl::StrCat("BufRendezvous::ProvideBuf already called for key ", @@ -166,11 +175,13 @@ TEST_F(BufRendezvousTest, ErrorDuplicatePut) { TEST_F(BufRendezvousTest, ErrorDeleteNonEmpty) { Status cons_status; - br_->ConsumeBuf(*kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [&cons_status](const Status& s, BufRendezvous::Hook* h) { - cons_status = s; - EXPECT_EQ(h, nullptr); - }); + br_->ConsumeBuf( + *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, + [&cons_status](const Status& s, BufRendezvous::Hook* h) { + cons_status = s; + EXPECT_EQ(h, nullptr); + }, + &cm_); EXPECT_TRUE(cons_status.ok()); br_.reset(); EXPECT_FALSE(cons_status.ok()); @@ -188,12 +199,15 @@ TEST_F(BufRendezvousTest, AbortNonEmpty) { [&cons_note, &cons_status](const Status& s, BufRendezvous::Hook* h) { cons_status = s; cons_note.Notify(); - }); - br_->ProvideBuf("key1", default_device_, fake_device_context_, &a_, aa_, - [&prod_note, &prod_status](const Status& s) { - prod_status = s; - prod_note.Notify(); - }); + }, + &cm_); + br_->ProvideBuf( + "key1", default_device_, fake_device_context_, &a_, aa_, + [&prod_note, &prod_status](const Status& s) { + prod_status = s; + prod_note.Notify(); + }, + &cm_); br_->StartAbort(errors::Internal("Falling sky detected")); prod_note.WaitForNotification(); cons_note.WaitForNotification(); @@ -218,12 +232,15 @@ TEST_F(BufRendezvousTest, UseAfterAbort) { [&cons_note, &cons_status](const Status& s, BufRendezvous::Hook* h) { cons_status = s; cons_note.Notify(); - }); - br_->ProvideBuf("key1", default_device_, fake_device_context_, &a_, aa_, - [&prod_note, &prod_status](const Status& s) { - prod_status = s; - prod_note.Notify(); - }); + }, + &cm_); + br_->ProvideBuf( + "key1", default_device_, fake_device_context_, &a_, aa_, + [&prod_note, &prod_status](const Status& s) { + prod_status = s; + prod_note.Notify(); + }, + &cm_); prod_note.WaitForNotification(); cons_note.WaitForNotification(); EXPECT_FALSE(prod_status.ok()); @@ -237,18 +254,161 @@ TEST_F(BufRendezvousTest, UseAfterAbort) { TEST_F(BufRendezvousTest, DeviceIncarnationMismatch) { Status cons_status; Notification note; - br_->ProvideBuf(*kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [](const Status&) {}); + br_->ProvideBuf( + *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, + [](const Status&) {}, /*cancellation_manager=*/nullptr); const uint64 incorrect_incarnation = 23456; br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, incorrect_incarnation, [¬e, &cons_status](const Status& s, BufRendezvous::Hook* h) { cons_status = s; note.Notify(); - }); + }, + /*cancellation_manager=*/nullptr); note.WaitForNotification(); EXPECT_TRUE(errors::IsFailedPrecondition(cons_status)); } +TEST_F(BufRendezvousTest, ProvideThenCancel) { + Status status; + Notification note; + br_->ProvideBuf( + *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, + [&status, ¬e](const Status& s) { + status = s; + note.Notify(); + }, + &cm_); + cm_.StartCancel(); + note.WaitForNotification(); + EXPECT_TRUE(errors::IsCancelled(status)); + EXPECT_NE( + status.error_message().find(absl::StrCat( + "Operation was cancelled for BufRendezvous key ", *kDefaultKey)), + string::npos); +} + +TEST_F(BufRendezvousTest, CancelThenProvide) { + Status status; + Notification note; + cm_.StartCancel(); + br_->ProvideBuf( + *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, + [&status, ¬e](const Status& s) { + status = s; + note.Notify(); + }, + &cm_); + note.WaitForNotification(); + EXPECT_TRUE(errors::IsCancelled(status)); + EXPECT_NE( + status.error_message().find(absl::StrCat( + "Operation was cancelled for BufRendezvous key ", *kDefaultKey)), + string::npos); +} + +TEST_F(BufRendezvousTest, ConsumeThenCancel) { + Status status; + Notification note; + br_->ConsumeBuf( + *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, + [&status, ¬e](const Status& s, BufRendezvous::Hook* h) { + status = s; + note.Notify(); + }, + &cm_); + cm_.StartCancel(); + note.WaitForNotification(); + EXPECT_TRUE(errors::IsCancelled(status)); + EXPECT_NE( + status.error_message().find(absl::StrCat( + "Operation was cancelled for BufRendezvous key ", *kDefaultKey)), + string::npos); +} + +TEST_F(BufRendezvousTest, CancelThenConsume) { + Status status; + Notification note; + cm_.StartCancel(); + br_->ConsumeBuf( + *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, + [&status, ¬e](const Status& s, BufRendezvous::Hook* h) { + status = s; + note.Notify(); + }, + &cm_); + note.WaitForNotification(); + EXPECT_TRUE(errors::IsCancelled(status)); + EXPECT_NE( + status.error_message().find(absl::StrCat( + "Operation was cancelled for BufRendezvous key ", *kDefaultKey)), + string::npos); +} + +TEST_F(BufRendezvousTest, ProvideConsumeThenCancel) { + Status prod_status; + Status cons_status; + bool prod_callback_called = false; + bool cons_callback_called = false; + Notification note; + br_->ProvideBuf( + *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, + [¬e, &prod_status, &prod_callback_called](const Status& s) { + prod_status = s; + prod_callback_called = true; + note.Notify(); + }, + &cm_); + EXPECT_FALSE(prod_callback_called); + br_->ConsumeBuf( + *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, + [this, &cons_status, &cons_callback_called](const Status& s, + BufRendezvous::Hook* h) { + cons_status = s; + cons_callback_called = true; + ASSERT_TRUE(h != nullptr); + EXPECT_EQ(h->prod_dev, default_device_); + EXPECT_EQ(h->prod_ctx, fake_device_context_); + EXPECT_EQ(h->prod_value, &a_); + br_->DoneWithHook(h); + }, + &cm_); + note.WaitForNotification(); + cm_.StartCancel(); + EXPECT_TRUE(cons_callback_called); + EXPECT_TRUE(prod_callback_called); + TF_EXPECT_OK(cons_status); + TF_EXPECT_OK(prod_status); +} + +TEST_F(BufRendezvousTest, CancelThenProvideConsume) { + Status prod_status; + Status cons_status; + bool prod_callback_called = false; + bool cons_callback_called = false; + cm_.StartCancel(); + br_->ProvideBuf( + *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, + [&prod_status, &prod_callback_called](const Status& s) { + prod_status = s; + EXPECT_TRUE(errors::IsCancelled(prod_status)); + prod_callback_called = true; + }, + &cm_); + EXPECT_TRUE(prod_callback_called); + EXPECT_TRUE(errors::IsCancelled(prod_status)); + br_->ConsumeBuf( + *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, + [&cons_status, &cons_callback_called](const Status& s, + BufRendezvous::Hook* h) { + cons_status = s; + EXPECT_TRUE(errors::IsCancelled(cons_status)); + cons_callback_called = true; + }, + &cm_); + EXPECT_TRUE(cons_callback_called); + EXPECT_TRUE(errors::IsCancelled(cons_status)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index 9c46314af67..01b89494c0d 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -779,7 +779,7 @@ void CollectiveParamResolverLocal::StartAbort(const Status& s) { { mutex_lock l(status_mu_); if (!status_.ok()) { - VLOG(1) << "CollectiveParamResolverLocal already aborted. Ignoring " + VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring " "subsequent abortion with status: " << s; return; diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc index b958a25c091..44175a042a7 100644 --- a/tensorflow/core/common_runtime/collective_rma_local.cc +++ b/tensorflow/core/common_runtime/collective_rma_local.cc @@ -28,7 +28,7 @@ void CollectiveRemoteAccessLocal::RecvFromPeer( const string& key, Device* to_device, DeviceContext* to_device_ctx, const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, const DeviceLocality& client_locality, int dev_to_dev_stream_index, - const StatusCallback& done) { + CancellationManager* cancellation_manager, const StatusCallback& done) { VLOG(1) << "RecvFromPeer " << this << " from " << peer_device << " key " << key; if (!peer_is_local) { @@ -91,21 +91,23 @@ void CollectiveRemoteAccessLocal::RecvFromPeer( }; buf_rendezvous_.ConsumeBuf(key, from_device->name(), from_device->attributes().incarnation(), - consumer_callback); + consumer_callback, cancellation_manager); } void CollectiveRemoteAccessLocal::PostToPeer( const string& peer_device, const string& peer_task, const string& key, Device* from_device, DeviceContext* from_device_ctx, const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor, - const DeviceLocality& client_locality, const StatusCallback& done) { + const DeviceLocality& client_locality, + CancellationManager* cancellation_manager, const StatusCallback& done) { VLOG(1) << "PostToPeer " << this << " key " << key << " step_id_=" << step_id_; buf_rendezvous_.ProvideBuf(key, from_device, from_device_ctx, from_tensor, - from_alloc_attr, done); + from_alloc_attr, done, cancellation_manager); } void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task, + int64 timeout_in_ms, const StatusCallback& done) { // Assume local devices are always healthy. done(errors::Internal( diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h index 12aca901054..fb4ddf178e5 100644 --- a/tensorflow/core/common_runtime/collective_rma_local.h +++ b/tensorflow/core/common_runtime/collective_rma_local.h @@ -43,6 +43,7 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess { const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, const DeviceLocality& client_locality, int dev_to_dev_stream_index, + CancellationManager* cancellation_manager, const StatusCallback& done) override; void PostToPeer(const string& peer_device, const string& peer_task, @@ -51,9 +52,10 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess { const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor, const DeviceLocality& client_locality, + CancellationManager* cancellation_manager, const StatusCallback& done) override; - void CheckPeerHealth(const string& peer_task, + void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms, const StatusCallback& done) override; BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; } diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc index 2c606147f7d..30f6e372606 100644 --- a/tensorflow/core/common_runtime/collective_rma_local_test.cc +++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc @@ -52,6 +52,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test { cp, device_mgr_.get(), drl_.get(), kTaskName); rma_ = absl::make_unique(device_mgr_.get(), drl_.get(), kStepId); + cm_ = absl::make_unique(); } ~CollectiveRemoteAccessLocalTest() override = default; @@ -61,6 +62,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test { std::unique_ptr drl_; std::unique_ptr prl_; std::unique_ptr rma_; + std::unique_ptr cm_; }; TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) { @@ -74,7 +76,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) { rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/, "key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/, attr /*to_alloc_attr*/, &sink_tensor, dev_locality, - 0 /*stream_index*/, + 0 /*stream_index*/, cm_.get(), [&recv_note, &recv_status](const Status& s) { recv_status = s; recv_note.Notify(); @@ -90,7 +92,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) { rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0", cpu0 /*from_device*/, nullptr /*from_device_ctx*/, attr /*to_alloc_attr*/, &source_tensor, dev_locality, - [&send_note, &send_status](const Status& s) { + cm_.get(), [&send_note, &send_status](const Status& s) { send_status = s; send_note.Notify(); }); @@ -117,7 +119,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) { rma_->RecvFromPeer(kTaskName + "/device:CPU:1", kTaskName, true /*is_local*/, "key_0", cpu2 /*to_device*/, nullptr /*to_device_ctx*/, attr /*to_alloc_attr*/, &sink_tensor, dev_locality, - 0 /*stream_index*/, + 0 /*stream_index*/, cm_.get(), [&recv_note, &recv_status](const Status& s) { recv_status = s; recv_note.Notify(); @@ -135,7 +137,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) { rma_->PostToPeer(kTaskName + "/device:CPU:2", kTaskName, "key_0", cpu1 /*from_device*/, nullptr /*from_device_ctx*/, attr /*to_alloc_attr*/, &source_tensor, dev_locality, - [&send_note, &send_status](const Status& s) { + cm_.get(), [&send_note, &send_status](const Status& s) { send_status = s; send_note.Notify(); }); @@ -154,13 +156,100 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) { TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) { Status status; Notification done; - rma_->CheckPeerHealth(kTaskName, [&status, &done](const Status& s) { - status = s; - done.Notify(); - }); + rma_->CheckPeerHealth(kTaskName, /*timeout_in_ms=*/0, + [&status, &done](const Status& s) { + status = s; + done.Notify(); + }); done.WaitForNotification(); EXPECT_TRUE(errors::IsInternal(status)); } +TEST_F(CollectiveRemoteAccessLocalTest, RecvThenCancel) { + Device* cpu0 = nullptr; + AllocatorAttributes attr; + DeviceLocality dev_locality; + TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0)); + Tensor sink_tensor(DT_FLOAT, TensorShape({8})); + Notification recv_note; + Status recv_status; + rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/, + "key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/, + attr /*to_alloc_attr*/, &sink_tensor, dev_locality, + 0 /*stream_index*/, cm_.get(), + [&recv_note, &recv_status](const Status& s) { + recv_status = s; + recv_note.Notify(); + }); + cm_->StartCancel(); + recv_note.WaitForNotification(); + EXPECT_TRUE(cm_->IsCancelled()); + EXPECT_TRUE(errors::IsCancelled(recv_status)); +} + +TEST_F(CollectiveRemoteAccessLocalTest, CancelThenRecv) { + Device* cpu0 = nullptr; + AllocatorAttributes attr; + DeviceLocality dev_locality; + TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0)); + Tensor sink_tensor(DT_FLOAT, TensorShape({8})); + Notification recv_note; + Status recv_status; + cm_->StartCancel(); + rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/, + "key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/, + attr /*to_alloc_attr*/, &sink_tensor, dev_locality, + 0 /*stream_index*/, cm_.get(), + [&recv_note, &recv_status](const Status& s) { + recv_status = s; + recv_note.Notify(); + }); + recv_note.WaitForNotification(); + EXPECT_TRUE(cm_->IsCancelled()); + EXPECT_TRUE(errors::IsCancelled(recv_status)); +} + +TEST_F(CollectiveRemoteAccessLocalTest, PostThenCancel) { + Device* cpu0 = nullptr; + AllocatorAttributes attr; + DeviceLocality dev_locality; + TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0)); + Tensor source_tensor(DT_FLOAT, TensorShape({8})); + Notification send_note; + Status send_status; + rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0", + cpu0 /*from_device*/, nullptr /*from_device_ctx*/, + attr /*to_alloc_attr*/, &source_tensor, dev_locality, + cm_.get(), [&send_note, &send_status](const Status& s) { + send_status = s; + send_note.Notify(); + }); + cm_->StartCancel(); + send_note.WaitForNotification(); + EXPECT_TRUE(cm_->IsCancelled()); + EXPECT_TRUE(errors::IsCancelled(send_status)); +} + +TEST_F(CollectiveRemoteAccessLocalTest, CancelThenPost) { + Device* cpu0 = nullptr; + AllocatorAttributes attr; + DeviceLocality dev_locality; + TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0)); + Tensor source_tensor(DT_FLOAT, TensorShape({8})); + Notification send_note; + Status send_status; + cm_->StartCancel(); + rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0", + cpu0 /*from_device*/, nullptr /*from_device_ctx*/, + attr /*to_alloc_attr*/, &source_tensor, dev_locality, + cm_.get(), [&send_note, &send_status](const Status& s) { + send_status = s; + send_note.Notify(); + }); + send_note.WaitForNotification(); + EXPECT_TRUE(cm_->IsCancelled()); + EXPECT_TRUE(errors::IsCancelled(send_status)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index f87efb369ed..384ec836cdf 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -219,6 +219,7 @@ bool IsConstantFoldable( const std::unordered_map>* shape_map, const std::function& consider, + int64 max_constant_size_in_bytes, std::unordered_map>* shape_replacement_map) { if (n->IsConstant()) { @@ -233,6 +234,20 @@ bool IsConstantFoldable( if (consider && !consider(n)) { return false; } + if (shape_map != nullptr) { + // We can skip the node if an output is known to be oversized. + auto shape_it = shape_map->find(n->name()); + if (shape_it != shape_map->end()) { + for (int64 i = 0; i < shape_it->second.size(); ++i) { + const auto& out_shape = shape_it->second[i]; + if (out_shape.IsFullyDefined() && + out_shape.num_elements() * DataTypeSize(n->output_type(i)) > + max_constant_size_in_bytes) { + return false; + } + } + } + } if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) { return false; } @@ -280,6 +295,7 @@ void ConsiderConstantFoldableNode( std::unordered_map>* shape_replacement_map, bool* internal_node_inserted) { if (IsConstantFoldable(n, opts.shape_map, opts.consider, + opts.max_constant_size_in_bytes, shape_replacement_map)) { // A node is constant provided all of its non-control incoming Tensors come // from constant nodes, or it's a shape Op with statically known inputs in diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 5102681934e..572615bb3f8 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -172,10 +172,11 @@ Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op, /* mirror= */ true, &result_handle); activity.Stop(); if (!status.ok()) { - return errors::Internal("Failed copying input tensor from ", - handle_device->name(), " to ", - expected_input_device->name(), " in order to run ", - op->Name(), ": ", status.error_message()); + return Status( + status.code(), + absl::StrCat("Failed copying input tensor from ", handle_device->name(), + " to ", expected_input_device->name(), " in order to run ", + op->Name(), ": ", status.error_message())); } *result = result_handle; diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc index 31d5e05462c..9f5eb90ab64 100644 --- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc +++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc @@ -54,6 +54,9 @@ class MklEagerOpRewrite : public EagerOpRewrite { // Rewrite rule for Conv2D, Conv2DBackpropInput and Conv2DBackpropFilter. static bool RewriteConv2D(EagerOperation* op); + // Rewrite rule for FusedBatchNormV3 and FusedBatchNormGradV3 + static bool RewriteFusedBatchNormV3(EagerOperation* op); + // Calls op-specific rewrite function to create new MKL op. Status RewriteToMklOp(EagerOperation* orig_op, std::unique_ptr* mkl_op); @@ -110,9 +113,10 @@ MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line) InsertMKLEagerOps( {"FusedBatchNormGradV2", AlwaysRewrite, CreateGenericMklOp}); InsertMKLEagerOps( - {"FusedBatchNormGradV3", AlwaysRewrite, CreateGenericMklOp}); + {"FusedBatchNormGradV3", RewriteFusedBatchNormV3, CreateGenericMklOp}); InsertMKLEagerOps({"FusedBatchNormV2", AlwaysRewrite, CreateGenericMklOp}); - InsertMKLEagerOps({"FusedBatchNormV3", AlwaysRewrite, CreateGenericMklOp}); + InsertMKLEagerOps( + {"FusedBatchNormV3", RewriteFusedBatchNormV3, CreateGenericMklOp}); InsertMKLEagerOps({"MatMul", AlwaysRewrite, CreateGenericMklOp}); }; @@ -246,5 +250,15 @@ bool MklEagerOpRewrite::RewriteConv2D(EagerOperation* op) { return (padding != "EXPLICIT"); } +bool MklEagerOpRewrite::RewriteFusedBatchNormV3(EagerOperation* op) { + const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef(); + if (Check5DFormat(ndef)) { + VLOG(1) << "Eager Op Rewrite: FusedBatchNorm(Grad)V3 op currently does not " + << "support 5D tensors."; + return false; + } + return true; +} + } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc index a18301231cf..b56d97428b3 100644 --- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc +++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc @@ -130,6 +130,26 @@ REGISTER_TEST_ALL_TYPES(ConvOpsExplicitPadding_Negative); REGISTER_TEST_ALL_TYPES(MostOps_Positive); #undef REGISTER_TEST +#define REGISTER_TEST(NAME, T, INPUT) \ + TEST_F(EagerOpRewriteTest, NAME##_##T) { \ + std::vector Fused_BN_ops = {"FusedBatchNormV3", \ + "FusedBatchNormGradV3"}; \ + for (int i = 0; i < Fused_BN_ops.size(); ++i) { \ + auto orig_op = CreateOp(Fused_BN_ops[i]); \ + orig_op->MutableAttrs()->Set("T", T); \ + orig_op->MutableAttrs()->Set("data_format", "" DATA_FORMAT ""); \ + CheckRewrite(orig_op.get(), Fused_BN_ops[i]); \ + } \ + } +#define DATA_FORMAT "NCDHW" +REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_1); + +#define DATA_FORMAT "NDHWC" +REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_2); + +#undef DATA_FORMAT +#undef REGISTER_TEST + } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 0b17bf7e3e6..da37ad1b480 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -1116,6 +1116,28 @@ const char* TensorHandle::BackingDeviceName(Status* status) const { } } +const char* TensorHandle::DeviceType(Status* status) const { + if (VariantDeviceIsCustom(device())) { + status->Update( + tensorflow::errors::Unimplemented("Custom device unsupported")); + return nullptr; + } + status->Update(WaitUnknownDevice()); + tensorflow::Device* d = op_device(); + return (d == nullptr) ? "CPU" : d->parsed_name().type.c_str(); +} + +int TensorHandle::DeviceId(Status* status) const { + if (VariantDeviceIsCustom(device())) { + status->Update( + tensorflow::errors::Unimplemented("Custom device unsupported")); + return -1; + } + status->Update(WaitUnknownDevice()); + tensorflow::Device* d = op_device(); + return (d == nullptr) ? 0 : d->parsed_name().id; +} + tensorflow::ImmediateExecutionTensorHandle* TensorHandle::Copy() { Ref(); return this; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index f54ebd45889..b2bb24f5bc0 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -131,6 +131,8 @@ class TensorHandle : public ImmediateExecutionTensorHandle { const char* DeviceName(Status* status) const override; const char* BackingDeviceName(Status* status) const override; + const char* DeviceType(Status* status) const override; + int DeviceId(Status* status) const override; AbstractTensorInterface* Resolve(Status* status) override; ImmediateExecutionTensorHandle* Copy() override; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc index 286bf8775ce..715e7f48ef5 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc @@ -408,4 +408,63 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) { context->Unref(); } +TEST(TensorHandle_DeviceNameTest, OnLocalDevice) { + std::vector> devices; + devices.emplace_back( + CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0")); + devices.emplace_back( + CreateDevice("GPU", "/job:localhost/replica:0/task:0/device:GPU:0")); + StaticDeviceMgr local_device_mgr(std::move(devices)); + auto ctx = new EagerContext( + SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false, + false, &local_device_mgr, false, nullptr, nullptr); + + Device* dcpu = local_device_mgr.ListDevices()[0]; + Device* dgpu = local_device_mgr.ListDevices()[1]; + tensorflow::DataType dtype = DT_RESOURCE; + TensorShape shape = {2}; + Tensor tcpu(dtype, shape); + Tensor tgpu(dtype, shape); + Status s; + + TensorHandle* th_cpu = + TensorHandle::CreateLocalHandle(std::move(tcpu), dcpu, dcpu, dcpu, ctx); + const char* device_name = th_cpu->DeviceName(&s); + TF_EXPECT_OK(s); + ASSERT_TRUE(absl::StrContains(device_name, "CPU")) << device_name; + const char* backing_device_name = th_cpu->BackingDeviceName(&s); + TF_EXPECT_OK(s); + ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU")) + << backing_device_name; + const char* device_type = th_cpu->DeviceType(&s); + TF_EXPECT_OK(s); + ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type; + int device_id = th_cpu->DeviceId(&s); + TF_EXPECT_OK(s); + ASSERT_EQ(0, device_id) << device_id; + + TensorHandle* th_gpu = + TensorHandle::CreateLocalHandle(std::move(tgpu), dgpu, dgpu, dgpu, ctx); + device_name = th_gpu->DeviceName(&s); + TF_EXPECT_OK(s); + ASSERT_TRUE(absl::StrContains(device_name, "GPU")) << device_name; + backing_device_name = th_gpu->BackingDeviceName(&s); + TF_EXPECT_OK(s); + std::cout << "backing_device_name for GPU: " << backing_device_name + << std::endl; + ASSERT_TRUE(absl::StrContains(backing_device_name, "GPU")) + << backing_device_name; + device_type = th_gpu->DeviceType(&s); + TF_EXPECT_OK(s); + ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type; + device_id = th_gpu->DeviceId(&s); + TF_EXPECT_OK(s); + ASSERT_EQ(0, device_id) << device_id; + + th_cpu->Unref(); + th_gpu->Unref(); + ctx->Unref(); +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eval_const_tensor.cc b/tensorflow/core/common_runtime/eval_const_tensor.cc index cd286edabf9..35362982df5 100644 --- a/tensorflow/core/common_runtime/eval_const_tensor.cc +++ b/tensorflow/core/common_runtime/eval_const_tensor.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" @@ -123,6 +124,17 @@ bool HasCpuKernel(const Node& node) { .ok(); } +Status GetArgNodeIndex(const Node* node, int num_function_inputs, int* index) { + DCHECK(node->IsArg()); + TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", index)); + if (*index < 0 || num_function_inputs <= *index) { + return errors::Internal( + "Function instantiation included invalid input index: ", index, + " not in [0, ", num_function_inputs, ")."); + } + return Status::OK(); +} + // Extracts the subgraph ending at 'target_node' that is statically computable // and inserts into 'out_graph'. If statically computable, 'is_constant_graph' // will be set to true. @@ -130,7 +142,8 @@ Status ExtractConstantSubgraph( const Node& target_node, const ShapeRefiner& refiner, const std::unordered_map* cached_values, Graph* out_graph, bool* is_constant_graph, - std::vector>* const_inputs) { + std::vector>* const_inputs, + InferenceContext* outer_context) { *is_constant_graph = false; std::unordered_set const_inputs_added; @@ -187,8 +200,9 @@ Status ExtractConstantSubgraph( edges_to_visit.pop_front(); Node* current_node = current_edge->src(); - // If the node is stateful, assume the graph is not constant. - if (current_node->op_def().is_stateful()) { + // If the node is stateful, assume the graph is not constant unless it is + // an Arg node which is handled later on. + if (!current_node->IsArg() && current_node->op_def().is_stateful()) { *is_constant_graph = false; return Status::OK(); } @@ -223,9 +237,32 @@ Status ExtractConstantSubgraph( } // If there is nothing more to recurse down, see if - // the generator node is a constant. + // the generator node is a constant or an Arg node whose value is available + // in the `outer_context`. if (current_node->num_inputs() == 0) { - if (!current_node->IsConstant()) { + if (outer_context && current_node->IsArg()) { + const string& tensor_name = + strings::StrCat(current_node->name(), ":", 0); + // If we do not already have a constant Tensor for this Arg try to + // fetch it from the outer context. + if (const_inputs_added.count(tensor_name) == 0) { + int index; + TF_RETURN_IF_ERROR(GetArgNodeIndex( + current_node, outer_context->num_inputs(), &index)); + const Tensor* const_tensor = outer_context->input_tensor(index); + if (const_tensor) { + const_inputs->emplace_back(tensor_name, *const_tensor); + const_inputs_added.insert(tensor_name); + } else { + // Request a constant value for this Arg. If that is statically + // computable, shape refiner will re-run the shape inference for + // this function with this tensor's value. + outer_context->request_input_tensor(index); + *is_constant_graph = false; + return Status::OK(); + } + } + } else if (!current_node->IsConstant()) { // Generator node is not a constant, so subgraph is not // constant. *is_constant_graph = false; @@ -314,7 +351,8 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner, Tensor* result, GraphRunner* graph_runner, std::unordered_map* cached_values, int64 max_cached_value_size, - bool disable_constant_propagation) { + bool disable_constant_propagation, + InferenceContext* outer_context) { *evaluated = false; const Node* src = tensor.node; @@ -326,6 +364,22 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner, } } + // If the source node is an Arg return its value, if available in the outer + // context. + if (src->IsArg() && outer_context) { + int index; + TF_RETURN_IF_ERROR( + GetArgNodeIndex(src, outer_context->num_inputs(), &index)); + const Tensor* const_tensor = outer_context->input_tensor(index); + if (const_tensor) { + *evaluated = true; + *result = *(outer_context->input_tensor(index)); + } else { + outer_context->request_input_tensor(index); + } + return Status::OK(); + } + if (disable_constant_propagation) { return Status::OK(); } @@ -339,7 +393,7 @@ Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner, std::vector> const_inputs; TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values, &subgraph, &is_constant_graph, - &const_inputs)); + &const_inputs, outer_context)); if (!is_constant_graph) { return Status::OK(); } diff --git a/tensorflow/core/common_runtime/eval_const_tensor.h b/tensorflow/core/common_runtime/eval_const_tensor.h index fca5a235695..b63d492f657 100644 --- a/tensorflow/core/common_runtime/eval_const_tensor.h +++ b/tensorflow/core/common_runtime/eval_const_tensor.h @@ -53,13 +53,17 @@ class Tensor; // result size to cache. // disable_constant_propagation - if true, only Const node values will be // returned. +// outer_context - optional. The InferenceContext for the call node if inside +// a nested function. This is useful for doing constant propagation across +// Arg nodes. Status EvaluateConstantTensor( OutputTensor tensor, const ShapeRefiner& refiner, const OpRegistryInterface& ops, int32 graph_def_version, bool* evaluated, Tensor* result, GraphRunner* graph_runner = nullptr, std::unordered_map* cached_values = nullptr, int64 max_cached_value_size = 1024, - bool disable_constant_propagation = false); + bool disable_constant_propagation = false, + shape_inference::InferenceContext* outer_context = nullptr); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index cd3e73a30b9..03c23f32880 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1119,11 +1119,13 @@ bool ExecutorState::NodeDone( if (rendezvous_) { rendezvous_->StartAbort(s); } - if (collective_executor_) { - collective_executor_->StartAbort(s); - } if (cancellation_manager_) { cancellation_manager_->StartCancel(); + } else if (collective_executor_) { + // If there's cancellation_manager_, collective ops aborts + // collective_executor_ upon cancellation; otherwise we need to abort + // here. + collective_executor_->StartAbort(s); } } @@ -1267,11 +1269,13 @@ void ExecutorState::Finish() { if (rendezvous_) { rendezvous_->StartAbort(status); } - if (collective_executor_) { - collective_executor_->StartAbort(status); - } if (cancellation_manager_) { cancellation_manager_->StartCancel(); + } else if (collective_executor_) { + // If there's cancellation_manager_, collective ops aborts + // collective_executor_ upon cancellation; otherwise we need to abort + // here. + collective_executor_->StartAbort(status); } } delete this; diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc index e1d81a18464..e78fbef13de 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc @@ -423,7 +423,8 @@ void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank, col_params_->group.task_names[dst_idx], send_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), src_tensor, - col_ctx_->device_locality, done); + col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(), + done); } void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank, @@ -443,7 +444,8 @@ void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank, col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor, - col_ctx_->device_locality, 0 /*stream_index*/, done); + col_ctx_->device_locality, 0 /*stream_index*/, + col_ctx_->op_ctx->cancellation_manager(), done); } namespace { diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc index 112e6c9881c..97a1d0b46ce 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/test_collective_executor_mgr.h" #include "tensorflow/core/common_runtime/threadpool_device.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -165,11 +166,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { DeviceContext* to_device_ctx, const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, const DeviceLocality& client_locality, int stream_index, + CancellationManager* cancellation_manager, const StatusCallback& done) override { if (MaybeFail(done)) return; CollectiveRemoteAccessLocal::RecvFromPeer( peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx, - to_alloc_attr, to_tensor, client_locality, stream_index, done); + to_alloc_attr, to_tensor, client_locality, stream_index, + cancellation_manager, done); } void PostToPeer(const string& peer_device, const string& peer_task, @@ -178,11 +181,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor, const DeviceLocality& client_locality, + CancellationManager* cancellation_manager, const StatusCallback& done) override { if (MaybeFail(done)) return; CollectiveRemoteAccessLocal::PostToPeer( peer_device, peer_task, key, from_device, from_device_ctx, - from_alloc_attr, from_tensor, client_locality, done); + from_alloc_attr, from_tensor, client_locality, cancellation_manager, + done); } mutex mu_; @@ -618,6 +623,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { OpKernelContext::Params op_params; op_params.step_id = parent_->step_id_; op_params.device = device_; + op_params.cancellation_manager = &parent_->cancellation_manager_; gtl::InlinedVector inputs; inputs.push_back(TensorValue(&tensor_)); op_params.inputs = &inputs; @@ -710,6 +716,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { int bcast_recv_counter_ TF_GUARDED_BY(mu_) = 0; int bcast_send_counter_ TF_GUARDED_BY(mu_) = 0; int failure_count_ TF_GUARDED_BY(mu_) = 0; + CancellationManager cancellation_manager_; }; TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams1Task8GPU) { diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index 0b6edf74daf..846616b2c02 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -48,7 +48,8 @@ namespace test { // TODO(hongm): Convert `g` and `init` to using std::unique_ptr. Benchmark::Benchmark(const string& device, Graph* g, const SessionOptions* options, Graph* init, - Rendezvous* rendez, const char* executor_type) { + Rendezvous* rendez, const char* executor_type, + bool old_benchmark_api) { auto cleanup = gtl::MakeCleanup([g, init]() { delete g; delete init; @@ -59,7 +60,8 @@ Benchmark::Benchmark(const string& device, Graph* g, options = &default_options; } - testing::StopTiming(); + old_benchmark_api_ = old_benchmark_api; + if (old_benchmark_api_) testing::StopTiming(); string t = absl::AsciiStrToUpper(device); // Allow NewDevice to allocate a new threadpool with different number of // threads for each new benchmark. @@ -135,6 +137,10 @@ Benchmark::~Benchmark() { void Benchmark::Run(int iters) { RunWithRendezvousArgs({}, {}, iters); } +void Benchmark::Run(::testing::benchmark::State& state) { + RunWithRendezvousArgs({}, {}, state); +} + string GetRendezvousKey(const Node* node) { string send_device; TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device", &send_device)); @@ -149,9 +155,63 @@ string GetRendezvousKey(const Node* node) { recv_device, tensor_name, FrameAndIter(0, 0)); } +void Benchmark::RunWithRendezvousArgs( + const std::vector>& inputs, + const std::vector& outputs, ::testing::benchmark::State& state) { + CHECK(!old_benchmark_api_) + << "This method should only be called with new benchmark API"; + if (!device_ || state.max_iterations == 0) { + return; + } + Tensor unused; // In benchmark, we don't care the return value. + bool is_dead; + + // Warm up + Executor::Args args; + args.rendezvous = rendez_; + args.runner = [this](std::function closure) { + pool_->Schedule(closure); + }; + static const int kWarmupRuns = 3; + for (int i = 0; i < kWarmupRuns; ++i) { + for (const auto& p : inputs) { + Rendezvous::ParsedKey parsed; + TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed)); + TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false)); + } + TF_CHECK_OK(exec_->Run(args)); + for (const string& key : outputs) { + Rendezvous::ParsedKey parsed; + TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed)); + TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead)); + } + } + TF_CHECK_OK(device_->Sync()); + VLOG(3) << kWarmupRuns << " warmup runs done."; + + // Benchmark loop. Timer starts automatically at the beginning of the loop + // and ends automatically after the last iteration. + for (auto s : state) { + for (const auto& p : inputs) { + Rendezvous::ParsedKey parsed; + TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed)); + TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false)); + } + TF_CHECK_OK(exec_->Run(args)); + for (const string& key : outputs) { + Rendezvous::ParsedKey parsed; + TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed)); + TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead)); + } + } + TF_CHECK_OK(device_->Sync()); +} + void Benchmark::RunWithRendezvousArgs( const std::vector>& inputs, const std::vector& outputs, int iters) { + CHECK(old_benchmark_api_) << "This method should only be called when running " + "with old benchmark API"; if (!device_ || iters == 0) { return; } diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h index 9c6b1eb088c..fe161b6b939 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h @@ -26,6 +26,12 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +namespace testing { +namespace benchmark { +class State; +} // namespace benchmark +} // namespace testing + namespace tensorflow { class Device; @@ -40,23 +46,42 @@ class Benchmark { public: // "device" must be either "cpu" or "gpu". Takes ownership of "g", // "init", and one reference on "rendez" (if not null). + // + // old_benchmark_api: If true, the benchmark is running with older API + // * In the old API, the timer needs to be stopped/restarted + // by users. + // * In the new API, the timer starts automatically at the first + // iteration of the loop and stops after the last iteration. + // TODO(vyng) Remove this once we have migrated all code to newer API. Benchmark(const string& device, Graph* g, const SessionOptions* options = nullptr, Graph* init = nullptr, - Rendezvous* rendez = nullptr, const char* executor_type = ""); + Rendezvous* rendez = nullptr, const char* executor_type = "", + bool old_benchmark_api = true); ~Benchmark(); // Executes the graph for "iters" times. + // This function is deprecated. Use the overload that takes + // `benchmark::State&` + // instead. void Run(int iters); + void Run(::testing::benchmark::State& state); + // If "g" contains send/recv nodes, before each execution, we send // inputs to the corresponding recv keys in the graph, after each // execution, we recv outputs from the corresponding send keys in // the graph. In the benchmark, we throw away values returned by the // graph. + // This function is deprecated. Use the overload that takes + // `benchmark::State&` instead. void RunWithRendezvousArgs( const std::vector>& inputs, const std::vector& outputs, int iters); + void RunWithRendezvousArgs( + const std::vector>& inputs, + const std::vector& outputs, ::testing::benchmark::State& state); + private: thread::ThreadPool* pool_ = nullptr; // Not owned. Device* device_ = nullptr; // Not owned. @@ -66,6 +91,7 @@ class Benchmark { std::unique_ptr pflr_; FunctionLibraryRuntime* flr_; // Not owned. std::unique_ptr exec_; + bool old_benchmark_api_; TF_DISALLOW_COPY_AND_ASSIGN(Benchmark); }; diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index ff010ad8a63..2a0e5d35de5 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -148,13 +148,22 @@ Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder, Status CondBuilder::CreatePivotNodes() { // Construct the basic cond body (consisting of feeding in the predicate to // create pivot nodes). + + // This is a special pivot switch node for lowering. We mark this with a + // special _PivotSwitch attr on it as later on in the graph partitioner we + // do some special placement for Switch nodes and its necessary to distinguish + // between a "normal" Switch node and one of these pivot switches. We would + // like to place this node on the CPU always as the pred_ will be on the CPU + // as well (either a CPU op output or a GPU op with HostMemory annotation). + // TODO(b/171321391): Fix this for NUMA cases. Node* switch_pred; TF_RETURN_IF_ERROR( SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry(), &debug_info_) .Input(NodeOut(pred_)) .Input(NodeOut(pred_)) - .Device(if_op_->requested_device()), + .Attr("_PivotSwitch", true) + .Device("/CPU:0"), graph_, &switch_pred)); control_predecessor_ = switch_pred; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc index cf7d35409bb..b0304cfe293 100644 --- a/tensorflow/core/common_runtime/lower_if_op_test.cc +++ b/tensorflow/core/common_runtime/lower_if_op_test.cc @@ -147,6 +147,115 @@ TEST(LowerIfOpTest, Simple) { } } +TEST(LowerIfOpTest, GPUPlacement) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + // Add test functions for then and else branch. + FunctionDefLibrary f_lib_proto; + *(f_lib_proto.add_function()) = test::function::XTimesTwo(); + *(f_lib_proto.add_function()) = test::function::XTimesFour(); + + // Construct simple conditional that switches on `pred` and operates only on + // single input `A`. + Scope root = Scope::NewRootScope().ExitOnError(); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto)); + auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32); + auto x = ops::Placeholder(root.WithOpName("X"), DT_INT32); + auto y = ops::Placeholder(root.WithOpName("Y"), DT_INT32); + Node* pred; + TF_ASSERT_OK(NodeBuilder("greater", "Greater", &root.graph()->flib_def()) + .Input(x.node()) + .Input(y.node()) + .Device("/GPU:0") + .Finalize(root.graph(), &pred)); + Node* written_if; + std::vector inputs({NodeBuilder::NodeOut(a.node())}); + TF_ASSERT_OK( + NodeBuilder("if", "If", &root.graph()->flib_def()) + .Input(pred) + .Input(inputs) + .Attr("then_branch", FuncAttr("XTimesTwo")) + .Attr("else_branch", FuncAttr("XTimesFour")) + .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true) + .Attr("Tout", {DT_INT32}) + .Device("/GPU:0") + .Finalize(root.graph(), &written_if)); + TF_ASSERT_OK(root.DoShapeInference(written_if)); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + // The input graph has no switch or merge nodes. + int node_called_if_count = 0; + for (const auto* op : graph->op_nodes()) { + ASSERT_FALSE(op->IsSwitch()); + ASSERT_FALSE(op->IsMerge()); + if (op->name() == "if") { + ++node_called_if_count; + } + } + ASSERT_EQ(node_called_if_count, 1); + + TF_ASSERT_OK(Rewrite(&graph)); + + // Verify the resultant graph has switch and merge nodes, and a node called + // `if` (but not If nodes). + int switch_count = 0; + int merge_count = 0; + node_called_if_count = 0; + for (const auto* op : graph->op_nodes()) { + if (op->IsSwitch()) { + ++switch_count; + } + if (op->IsMerge()) { + ++merge_count; + } + ASSERT_NE(op->type_string(), "If"); + if (op->name() == "if") { + ++node_called_if_count; + } + } + // One switch for predicate and one for input (A). + ASSERT_EQ(switch_count, 2); + // One merge for the single output value of then and else, and one more merge + // to enforce then and else function call execution (`branch_executed` node). + ASSERT_EQ(merge_count, 2); + ASSERT_EQ(node_called_if_count, 1); + + // Verify execution. + ClientSession session(root, SessionOptionsWithInlining()); + { + RunMetadata metadata; + RunOptions options; + options.set_output_partition_graphs(true); + ClientSession::FeedType feeds; + feeds.emplace(Output(x.node()), Input::Initializer(5)); + feeds.emplace(Output(y.node()), Input::Initializer(10)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(options, feeds, {Output(written_if)}, {}, + &out_tensors, &metadata)); + GraphDef cpu_graph = metadata.partition_graphs(1); + int num_cpu_switch = 0; + for (const auto& node : cpu_graph.node()) { + if (node.op() == "Switch") { + ++num_cpu_switch; + } + } + EXPECT_EQ(num_cpu_switch, 2); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 40); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(x.node()), Input::Initializer(10)); + feeds.emplace(Output(y.node()), Input::Initializer(5)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 20); + } +} + TEST(LowerIfOpTest, BranchFunctionsWithoutOutputs) { using ::tensorflow::test::function::GDef; using ::tensorflow::test::function::NDef; diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator.cc index 4ec85457add..43a909466ed 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.cc +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.cc @@ -17,13 +17,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/mkl_cpu_allocator.h" -#ifdef _WIN32 -// Declare function to avoid unresolved symbol in VS -i_malloc_t i_malloc; -i_calloc_t i_calloc; -i_realloc_t i_realloc; -i_free_t i_free; -#endif namespace tensorflow { constexpr const char* MklCPUAllocator::kMaxLimitStr; diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 99c41e4c75e..1686b107c98 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -475,11 +475,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back( {csinfo_.fused_batch_norm_v3, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()}); rinfo_.push_back( {csinfo_.fused_batch_norm_grad_v3, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()}); #ifdef ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.fused_batch_norm_ex, native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex @@ -531,7 +531,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()}); rinfo_.push_back({csinfo_.max_pool3d_grad, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, Maxpool3DGradRewrite, GetRewriteCause()}); rinfo_.push_back( {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum), CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); @@ -1121,7 +1121,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd // node that can be merged with 'm'. static Node* GetConv2DOrBiasAdd(const Node* m) { - CHECK_NOTNULL(m); + DCHECK(m); Node* n = nullptr; DataType T_m; @@ -1288,7 +1288,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // So 1st input of BiasAddGrad connects with 3rd input of // Conv2DBackpropFilter and vice versa. static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) { - CHECK_NOTNULL(m); + DCHECK(m); Node* n = nullptr; DataType T_m; @@ -1548,7 +1548,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @return - true (if it is not a depth/batch wise pooling case); // false otherwise. static bool NonDepthBatchWisePoolRewrite(const Node* n) { - CHECK_NOTNULL(n); + DCHECK(n); string data_format_str; TensorFormat data_format; @@ -1575,7 +1575,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // and use default Eigen. But for depth_radius=2, MKL DNN optimized // path is taken, i.e., eigen node is rewritten by MKl DNN node. static bool LrnRewrite(const Node* n) { - CHECK_NOTNULL(n); + DCHECK(n); int depth_radius; TF_CHECK_OK(GetNodeAttr(n->def(), "depth_radius", &depth_radius)); @@ -1593,7 +1593,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { } static bool LrnGradRewrite(const Node* n) { - CHECK_NOTNULL(n); + DCHECK(n); bool do_rewrite = false; for (const Edge* e : n->in_edges()) { @@ -1687,8 +1687,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { } return true; } + static bool MaxpoolGradRewrite(const Node* n) { - CHECK_NOTNULL(n); + DCHECK(n); bool do_rewrite = false; for (const Edge* e : n->in_edges()) { // Rewrite only if there is corresponding Maxpool, i.e workspace is @@ -1705,6 +1706,32 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return do_rewrite; } + static bool Maxpool3DGradRewrite(const Node* n) { + DCHECK(n); + for (const Edge* e : n->in_edges()) { + // Rewrite only if there is corresponding Maxpool3D, i.e., workspace is + // available + if (e->dst()->type_string() == csinfo_.max_pool3d_grad && + e->dst_input() == 1 && + e->src()->type_string() == + mkl_op_registry::GetMklOpName(csinfo_.max_pool3d) && + e->src_output() == 0) { + return true; + } + } + return false; + } + + static bool FusedBatchNormV3Rewrite(const Node* n) { + DCHECK(n); + if (Check5DFormat(n->def())) { + VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not " + << "support 5D tensors."; + return false; + } + return true; + } + static bool FusedBatchNormExRewrite(const Node* n) { DCHECK(n); @@ -2058,7 +2085,7 @@ void MklLayoutRewritePass::GetNodesProducingTFTensorList( int list_length, std::vector* output_nodes) { CHECK_LT(*input_idx, inputs.size()); CHECK_GT(list_length, 0); - CHECK_NOTNULL(output_nodes); + DCHECK(output_nodes); output_nodes->reserve(list_length); while (list_length != 0) { @@ -2095,7 +2122,7 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, // device of the original // node. .Finalize(&**g, out)); - CHECK_NOTNULL(*out); // Make sure we got a valid object before using it + DCHECK(*out); // Make sure we got a valid object before using it // If number of inputs to the original node is > 0, then we add // control dependency between 1st input (index 0) of the original node and @@ -2123,7 +2150,7 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( int list_length, std::vector* output_nodes) { CHECK_LT(*input_idx, inputs.size()); CHECK_GT(list_length, 0); - CHECK_NOTNULL(output_nodes); + DCHECK(output_nodes); output_nodes->reserve(list_length); while (list_length != 0) { @@ -2151,9 +2178,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( void MklLayoutRewritePass::GetNodeProducingMklTensor( std::unique_ptr* g, const Node* orig_node, Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { - CHECK_NOTNULL(n); - CHECK_NOTNULL(mkl_node); - CHECK_NOTNULL(mkl_node_output_slot); + DCHECK(n); + DCHECK(mkl_node); + DCHECK(mkl_node_output_slot); // If this is an MKL op, then it will create extra output for MKL layout. DataType T; @@ -2172,7 +2199,7 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor( // DummyMklTensor node has no input and generates only 1 output // (dummy Mkl tensor) as output slot number 0. GetDummyMklTensorNode(g, mkl_node, orig_node); - CHECK_NOTNULL(*mkl_node); + DCHECK(*mkl_node); *mkl_node_output_slot = 0; } } @@ -2183,7 +2210,7 @@ int MklLayoutRewritePass::SetUpContiguousInputs( NodeBuilder* nb, const Node* old_node, std::vector* workspace_tensors, bool are_workspace_tensors_available) { - CHECK_NOTNULL(workspace_tensors); + DCHECK(workspace_tensors); CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); // TODO(nhasabni): Temporary solution to connect filter input of @@ -2201,7 +2228,7 @@ int MklLayoutRewritePass::SetUpContiguousInputs( Node* filter_node = nullptr; TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, &filter_node)); - CHECK_NOTNULL(filter_node); + DCHECK(filter_node); // Now check which nodes receive from filter_node. Filter feeds as // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and @@ -2451,7 +2478,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( std::unique_ptr* g, const Node* orig_node, NodeBuilder* nb, std::vector* ws_tensors, bool* are_ws_tensors_added) { bool workspace_edge_added = false; // Default initializer - CHECK_NOTNULL(are_ws_tensors_added); + DCHECK(are_ws_tensors_added); *are_ws_tensors_added = false; // Default initializer DataType T; @@ -2506,7 +2533,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( mkl_op_registry::GetMklOpName(ws.fwd_op) && e->dst_input() == ws.bwd_slot) { nb->Attr("workspace_enabled", true); - CHECK_NOTNULL(ws_tensors); + DCHECK(ws_tensors); // Add workspace edge between fwd op and bwd op. ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot)); // Check if we are running in native format mode. If so, @@ -2540,9 +2567,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node); GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node); - CHECK_NOTNULL(dmt_ws); - CHECK_NOTNULL(dmt_mkl_ws); - CHECK_NOTNULL(ws_tensors); + DCHECK(dmt_ws); + DCHECK(dmt_mkl_ws); + DCHECK(ws_tensors); // We add dummy tensor as workspace tensor. ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0)); // We add dummy tensor as Mkl tensor for workspace tensor. @@ -3204,8 +3231,9 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr* g, // BiasAdd has only 1 output (at slot 0) and merged node also has only 1 // output (at slot 0). const int kConv2DWithBiasOutputSlot = 0; - CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(), - e->dst_input())); + auto new_edge = (*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, + e->dst(), e->dst_input()); + DCHECK(new_edge); } } @@ -3498,8 +3526,9 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( (*g)->AddControlEdge(new_node, e->dst(), true); } } else { - CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx, - e->dst(), e->dst_input())); + auto new_edge = (*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx, + e->dst(), e->dst_input()); + DCHECK(new_edge); } } unique_node.clear(); @@ -3512,8 +3541,9 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( (*g)->AddControlEdge(new_node, e->dst(), true); } } else { - CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx, - e->dst(), e->dst_input())); + auto new_edge = (*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx, + e->dst(), e->dst_input()); + DCHECK(new_edge); } } @@ -3534,8 +3564,8 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( Status MklLayoutRewritePass::MergeNode(std::unique_ptr* g, Node* m, Node* n) { - CHECK_NOTNULL(m); - CHECK_NOTNULL(n); + DCHECK(m); + DCHECK(n); if (((m->type_string() == csinfo_.bias_add && n->type_string() == csinfo_.conv2d)) || @@ -3641,10 +3671,11 @@ Status MklLayoutRewritePass::RewriteNodeForLayoutPropagation( (*g)->AddControlEdge(*new_node, e->dst(), true); } } else { - CHECK_NOTNULL((*g)->AddEdge( + auto new_edge = (*g)->AddEdge( *new_node, GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), - e->dst(), e->dst_input())); + e->dst(), e->dst_input()); + DCHECK(new_edge); } } return Status::OK(); @@ -3798,7 +3829,7 @@ MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const { const MklLayoutRewritePass::RewriteInfo* MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { - CHECK_NOTNULL(n); + DCHECK(n); // QuantizedOps may have attributes other than "T", so decoupled the check // with a function, CheckForQuantizedNodeRewrite(const Node*). @@ -4017,8 +4048,9 @@ bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr* g, if (IsConstant(e_metadata->src())) { Node* e_metadata_dst = e_metadata->dst(); int e_metadata_in_slot = e_metadata->dst_input(); - CHECK_NOTNULL((*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst, - e_metadata_in_slot)); + auto new_edge = (*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst, + e_metadata_in_slot); + DCHECK(new_edge); (*g)->RemoveEdge(e_metadata); return true; @@ -4090,7 +4122,7 @@ bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr* g, bool MklLayoutRewritePass::RunPass(std::unique_ptr* g) { bool result = false; - CHECK_NOTNULL(g); + DCHECK(g); DumpGraph("Before running MklLayoutRewritePass", &**g); diff --git a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc index 4366a7892d3..fda5ad93352 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc @@ -3394,6 +3394,37 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV2_Negative) { REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_Positive); #undef REGISTER_TEST +#define REGISTER_TEST(NAME, T, INPUT) \ + TEST_F(MklLayoutPassTest, NAME##_##T) { \ + InitGraph( \ + "node { name: 'A' op: '" #INPUT "'}" \ + "node { name: 'B' op: 'Float32Input'}" \ + "node { name: 'C' op: 'Float32Input'}" \ + "node { name: 'D' op: 'Float32Input'}" \ + "node { name: 'E' op: 'Float32Input'}" \ + "node { name: 'F' op: 'FusedBatchNormV3'" \ + " attr { key: 'T' value { type: " #T " } }" \ + " attr { key: 'U' value { type: DT_FLOAT } }" \ + " attr { key: 'data_format' value { s: " DATA_FORMAT " } }" \ + " attr { key: 'epsilon' value { f: 0.0001 } }" \ + " attr { key: 'is_training' value { b: true } }" \ + " input: ['A', 'B', 'C', 'D', 'E'] }" \ + "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \ + " input: ['A', 'F'] }"); \ + EXPECT_EQ(DoMklLayoutOptimizationPass(), \ + "A(" #INPUT ");B(Float32Input);C(Float32Input);" \ + "D(Float32Input);E(Float32Input);F(FusedBatchNormV3);G(Zeta)" \ + "|A->F;A->G;B->F:1;C->F:2;D->F:3;E->F:4;F->G:1"); \ +} +#define DATA_FORMAT "'NCDHW'" +REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_5D_Negative_1); + +#define DATA_FORMAT "'NDHWC'" +REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_5D_Negative_2); + +#undef DATA_FORMAT +#undef REGISTER_TEST + TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) { DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); InitGraph( @@ -3417,6 +3448,38 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) { "B->F:1;C->F:2;D->F:3;E->F:4;F->G:1"); } +#define REGISTER_TEST(NAME, T, INPUT) \ + TEST_F(MklLayoutPassTest, NAME##_##T) { \ + InitGraph( \ + "node { name: 'A' op: '" #INPUT "'}" \ + "node { name: 'B' op: '" #INPUT "'}" \ + "node { name: 'C' op: 'Float32Input'}" \ + "node { name: 'D' op: 'Float32Input'}" \ + "node { name: 'E' op: 'Float32Input'}" \ + "node { name: 'F' op: 'Float32Input'}" \ + "node { name: 'G' op: 'FusedBatchNormGradV3'" \ + " attr { key: 'T' value { type: " #T " } }" \ + " attr { key: 'U' value { type: DT_FLOAT } }" \ + " attr { key: 'data_format' value { s: " DATA_FORMAT " } }" \ + " attr { key: 'epsilon' value { f: 0.0001 } }" \ + " attr { key: 'is_training' value { b: true } }" \ + " input: ['A', 'B', 'C', 'D', 'E', 'F'] }" \ + "node { name: 'H' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \ + " input: ['A', 'G'] }"); \ + EXPECT_EQ(DoMklLayoutOptimizationPass(), \ + "A(" #INPUT ");B(" #INPUT ");C(Float32Input);D(Float32Input);" \ + "E(Float32Input);F(Float32Input);G(FusedBatchNormGradV3);H(Zeta)" \ + "|A->G;A->H;B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1"); \ +} +#define DATA_FORMAT "'NCDHW'" +REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_1); + +#define DATA_FORMAT "'NDHWC'" +REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_2); + +#undef DATA_FORMAT +#undef REGISTER_TEST + #ifdef ENABLE_MKLDNN_V1 #define REGISTER_TEST(NAME, T, INPUT) \ TEST_F(MklLayoutPassTest, NAME##_##T) { \ diff --git a/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc b/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc index c29752d3c2c..8d64f6e69db 100644 --- a/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc +++ b/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc @@ -16,7 +16,6 @@ limitations under the License. #ifdef INTEL_MKL #include "tensorflow/core/common_runtime/threadpool_device.h" - #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/logging.h" @@ -37,15 +36,6 @@ TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) { EXPECT_EQ(omp_get_max_threads(), (port::NumSchedulableCPUs() + ht - 1) / ht); } -TEST(MKLThreadPoolDeviceTest, TestOmpPreSets) { - SessionOptions options; - setenv("OMP_NUM_THREADS", "314", 1); - - ThreadPoolDevice* tp = new ThreadPoolDevice( - options, "/device:CPU:0", Bytes(256), DeviceLocality(), cpu_allocator()); - - EXPECT_EQ(omp_get_max_threads(), 314); -} #endif // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL) } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/permuter.cc b/tensorflow/core/common_runtime/permuter.cc index 1053cae1f7d..9aee5e5d5c9 100644 --- a/tensorflow/core/common_runtime/permuter.cc +++ b/tensorflow/core/common_runtime/permuter.cc @@ -90,7 +90,7 @@ void Permuter::DispatchSend(int src_rank, int target_rank, const Tensor* tensor, col_params_->group.task_names[target_rank], send_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality, - done); + col_ctx_->op_ctx->cancellation_manager(), done); } void Permuter::DispatchRecv(int src_rank, int target_rank, Tensor* tensor, @@ -107,7 +107,7 @@ void Permuter::DispatchRecv(int src_rank, int target_rank, Tensor* tensor, col_params_->task.is_local[src_rank], recv_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality, - 0, done); + 0, col_ctx_->op_ctx->cancellation_manager(), done); } namespace { REGISTER_COLLECTIVE(Permute, Permuter); diff --git a/tensorflow/core/common_runtime/permuter_test.cc b/tensorflow/core/common_runtime/permuter_test.cc index 1b65a9ebe5f..10c527ca573 100644 --- a/tensorflow/core/common_runtime/permuter_test.cc +++ b/tensorflow/core/common_runtime/permuter_test.cc @@ -77,11 +77,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { DeviceContext* to_device_ctx, const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, const DeviceLocality& client_locality, int stream_index, + CancellationManager* cancellation_manager, const StatusCallback& done) override { if (MaybeFail(done)) return; CollectiveRemoteAccessLocal::RecvFromPeer( peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx, - to_alloc_attr, to_tensor, client_locality, stream_index, done); + to_alloc_attr, to_tensor, client_locality, stream_index, + cancellation_manager, done); } void PostToPeer(const string& peer_device, const string& peer_task, @@ -90,11 +92,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor, const DeviceLocality& client_locality, + CancellationManager* cancellation_manager, const StatusCallback& done) override { if (MaybeFail(done)) return; CollectiveRemoteAccessLocal::PostToPeer( peer_device, peer_task, key, from_device, from_device_ctx, - from_alloc_attr, from_tensor, client_locality, done); + from_alloc_attr, from_tensor, client_locality, cancellation_manager, + done); } mutex mu_; @@ -361,6 +365,7 @@ class PermuterTest : public ::testing::Test { OpKernelContext::Params op_params; op_params.step_id = parent_->step_id_; op_params.device = device_; + op_params.cancellation_manager = &parent_->cancellation_manager_; gtl::InlinedVector inputs; inputs.push_back(TensorValue(&tensor_input_)); op_params.inputs = &inputs; @@ -427,6 +432,7 @@ class PermuterTest : public ::testing::Test { mutex mu_; int permute_counter_ TF_GUARDED_BY(mu_) = 0; std::vector permutation_; + CancellationManager cancellation_manager_; }; // TODO(b/113171733): change to use TEST_P. diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc index b7a3bd11ec6..e664eb90865 100644 --- a/tensorflow/core/common_runtime/ring_alg.cc +++ b/tensorflow/core/common_runtime/ring_alg.cc @@ -278,12 +278,17 @@ void RingAlg::StartAbort(const Status& s) { status_.Update(s); } } - // If this is the initial entry to abort mode then invoke StartAbort - // on the CollectiveExecutor that invoked us. That should start - // cancellation on all of the outstanding CollectiveRemoteAccess - // actions. + // If this is the initial entry to abort mode and it's not a cancellation, + // then invoke StartAbort on the CollectiveExecutor that invoked us. That + // should start cancellation on all of the outstanding CollectiveRemoteAccess + // actions. If it's cancellation all pending send/recv should be cancelled as + // well and there's then no need to abort. if (abort_started) { - col_ctx_->col_exec->StartAbort(s); + if (col_ctx_->op_ctx->cancellation_manager() == nullptr || + (!col_ctx_->op_ctx->cancellation_manager()->IsCancelled() && + !col_ctx_->op_ctx->cancellation_manager()->IsCancelling())) { + col_ctx_->col_exec->StartAbort(s); + } } } @@ -389,7 +394,8 @@ void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) { col_params_->group.task_names[send_to_dev_idx], send_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk, - col_ctx_->device_locality, done); + col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(), + done); } void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) { @@ -409,7 +415,8 @@ void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) { col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor, - col_ctx_->device_locality, rf->subdiv_idx, done); + col_ctx_->device_locality, rf->subdiv_idx, + col_ctx_->op_ctx->cancellation_manager(), done); } string RingAlg::FieldState() { diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc index 6b51993e2f4..1f23ee1a8a7 100644 --- a/tensorflow/core/common_runtime/ring_gatherer_test.cc +++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc @@ -70,12 +70,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, const DeviceLocality& client_locality, int dev_to_dev_stream_index, + CancellationManager* cancellation_manager, const StatusCallback& done) override { if (MaybeFail(done)) return; CollectiveRemoteAccessLocal::RecvFromPeer( peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx, to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index, - done); + cancellation_manager, done); } void PostToPeer(const string& peer_device, const string& peer_task, @@ -84,11 +85,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor, const DeviceLocality& client_locality, + CancellationManager* cancellation_manager, const StatusCallback& done) override { if (MaybeFail(done)) return; CollectiveRemoteAccessLocal::PostToPeer( peer_device, peer_task, key, from_device, from_device_ctx, - from_alloc_attr, from_tensor, client_locality, done); + from_alloc_attr, from_tensor, client_locality, cancellation_manager, + done); } mutex mu_; @@ -442,6 +445,7 @@ class RingGathererTest : public ::testing::Test { OpKernelContext::Params op_params; op_params.step_id = kStepId; op_params.device = device_; + op_params.cancellation_manager = &parent_->cancellation_manager_; gtl::InlinedVector inputs; inputs.push_back(TensorValue(&input_tensor_)); op_params.inputs = &inputs; @@ -523,6 +527,7 @@ class RingGathererTest : public ::testing::Test { std::unique_ptr gpu_ring_order_; mutex mu_; int32 gather_counter_ TF_GUARDED_BY(mu_) = 0; + CancellationManager cancellation_manager_; }; CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc index 71f0226549f..2a3ca4f275a 100644 --- a/tensorflow/core/common_runtime/ring_reducer.cc +++ b/tensorflow/core/common_runtime/ring_reducer.cc @@ -256,7 +256,7 @@ bool RingReducer::RunAsyncParts() { rf->action = RF_REDUCE; Status s = collective_util::ComputeBinOp( col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device, - col_params_->merge_op.get(), &rf->chunk, &rf->tmp_chunk); + col_params_->merge_op, &rf->chunk, &rf->tmp_chunk); if (!s.ok()) { aborted = true; StartAbort(s); @@ -266,13 +266,12 @@ bool RingReducer::RunAsyncParts() { } break; case RF_REDUCE: - if (!rf->second_pass && col_params_->final_op.get() && - rf->is_final) { + if (!rf->second_pass && col_params_->final_op && rf->is_final) { rf->action = RF_FINALIZE; group_size_tensor_ready_.WaitForNotification(); Status s = collective_util::ComputeBinOp( col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device, - col_params_->final_op.get(), &rf->chunk, &group_size_tensor_); + col_params_->final_op, &rf->chunk, &group_size_tensor_); if (!s.ok()) { aborted = true; StartAbort(s); diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc index ad20a243151..3b153e4ca1d 100644 --- a/tensorflow/core/common_runtime/ring_reducer_test.cc +++ b/tensorflow/core/common_runtime/ring_reducer_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/test_collective_executor_mgr.h" #include "tensorflow/core/common_runtime/threadpool_device.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -70,12 +71,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, const DeviceLocality& client_locality, int dev_to_dev_stream_index, + CancellationManager* cancellation_manager, const StatusCallback& done) override { if (MaybeFail(done)) return; CollectiveRemoteAccessLocal::RecvFromPeer( peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx, to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index, - done); + cancellation_manager, done); } void PostToPeer(const string& peer_device, const string& peer_task, @@ -84,11 +86,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor, const DeviceLocality& client_locality, + CancellationManager* cancellation_manager, const StatusCallback& done) override { if (MaybeFail(done)) return; CollectiveRemoteAccessLocal::PostToPeer( peer_device, peer_task, key, from_device, from_device_ctx, - from_alloc_attr, from_tensor, client_locality, done); + from_alloc_attr, from_tensor, client_locality, cancellation_manager, + done); } mutex mu_; @@ -462,15 +466,16 @@ class RingReducerTest : public ::testing::Test { } void DoReduce() { - col_params_.merge_op = - GetAdd(col_params_.instance.data_type, device_type_, device_); - col_params_.final_op = - GetDiv(col_params_.instance.data_type, device_type_, device_); + merge_op_ = GetAdd(col_params_.instance.data_type, device_type_, device_); + final_op_ = GetDiv(col_params_.instance.data_type, device_type_, device_); + col_params_.merge_op = merge_op_.get(); + col_params_.final_op = final_op_.get(); // Prepare an OpKernelContext. OpKernelContext::Params op_params; op_params.step_id = kStepId; op_params.device = device_; + op_params.cancellation_manager = &parent_->cancellation_manager_; gtl::InlinedVector inputs; inputs.push_back(TensorValue(&tensor_)); op_params.inputs = &inputs; @@ -531,6 +536,8 @@ class RingReducerTest : public ::testing::Test { Tensor tensor_; Device* device_; CollectiveParams col_params_; + std::unique_ptr merge_op_; + std::unique_ptr final_op_; std::unique_ptr ca_; std::unique_ptr ctx_; Status status_; @@ -550,6 +557,7 @@ class RingReducerTest : public ::testing::Test { std::unique_ptr gpu_ring_order_; mutex mu_; int32 reduce_counter_ TF_GUARDED_BY(mu_) = 0; + CancellationManager cancellation_manager_; }; CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index a968aaf09b6..375f809b31b 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -59,13 +60,15 @@ namespace { constexpr char kArgOp[] = "_Arg"; constexpr char kRetvalOp[] = "_Retval"; +} // namespace + // Runs shape inference for the given node using the given ShapeRefiner. // The node must be a sub-node of a function node and the outer_context is // the inference context of that function node in the outer graph. -Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner, - InferenceContext* outer_context) { - TF_RETURN_IF_ERROR(refiner->AddNode(node)); - InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node)); +Status ShapeRefiner::InferShapesForFunctionSubNode( + const Node* node, InferenceContext* outer_context) { + TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context)); + InferenceContext* node_context = CHECK_NOTNULL(GetContext(node)); if (StringPiece(node->type_string()) == kArgOp) { // Handle special node: function input. @@ -126,8 +129,6 @@ Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner, return Status::OK(); } -} // namespace - // TODO(cwhipkey): When an inference context inside function has // requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i) // set when input(i) is an _Arg op, then this request should propagate to @@ -167,8 +168,8 @@ Status ShapeRefiner::InferShapesForFunction( auto node_shape_inference_lambda = [this, &outer_context, &function_nodes, &inference_status](const Node* node) { if (!inference_status.ok()) return; - inference_status = InferShapesForFunctionSubNode( - node, this, outer_context->get_context()); + inference_status = + InferShapesForFunctionSubNode(node, outer_context->get_context()); function_nodes.insert(node); }; @@ -187,6 +188,11 @@ Status ShapeRefiner::InferShapesForFunction( } Status ShapeRefiner::AddNode(const Node* node) { + return AddNodeInternal(node, /*outer_context=*/nullptr); +} + +Status ShapeRefiner::AddNodeInternal( + const Node* node, shape_inference::InferenceContext* outer_context) { // Create the inference context for this node with the existing input shapes. std::unique_ptr ic(new InferenceContext( graph_def_version_, node->def(), node->op_def(), @@ -240,7 +246,7 @@ Status ShapeRefiner::AddNode(const Node* node) { new ExtendedInferenceContext(std::move(ic), node)); // Run the shape inference function, and return if there was an error. - TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get())); + TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context)); // Store the resulting context object in the map. node_to_context_[node].swap(ec); @@ -385,25 +391,25 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) { return RunShapeFn(node, op_reg_data, node_ext_context); } -Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node, - int dst_idx, bool* evaluated, - Tensor* result) { +Status ShapeRefiner::EvaluateConstantTensorForEdge( + const Node* node, int dst_idx, bool* evaluated, Tensor* result, + InferenceContext* outer_context) { *evaluated = false; const Edge* input_edge; TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge)); OutputTensor tensor(input_edge->src(), input_edge->src_output()); - return EvaluateConstantTensor(tensor, *this, *ops_registry_, - graph_def_version_, evaluated, result, - &graph_runner_, &const_tensor_map_, - kMaxTensorSize, disable_constant_propagation_); + return EvaluateConstantTensor( + tensor, *this, *ops_registry_, graph_def_version_, evaluated, result, + &graph_runner_, &const_tensor_map_, kMaxTensorSize, + disable_constant_propagation_, outer_context); } -Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node, - int dst_idx, bool* evaluated, - int64* result) { +Status ShapeRefiner::EvaluateConstantIntScalarEdge( + const Node* node, int dst_idx, bool* evaluated, int64* result, + shape_inference::InferenceContext* outer_context) { Tensor scalar; - TF_RETURN_IF_ERROR( - EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar)); + TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated, + &scalar, outer_context)); if (*evaluated) { if (scalar.NumElements() != 1) { return errors::InvalidArgument( @@ -424,9 +430,9 @@ Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node, return Status::OK(); } -Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, - const Node* node, int dst_idx, - ShapeHandle* result) { +Status ShapeRefiner::ConstantPartialShape( + InferenceContext* target_context, const Node* node, int dst_idx, + ShapeHandle* result, shape_inference::InferenceContext* outer_context) { const Edge* input_edge; TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge)); @@ -437,8 +443,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, if (src_context->Value(src_context->Rank(src_shape)) == 0) { Tensor t; bool evaluated = false; - TF_RETURN_IF_ERROR( - EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t)); + TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, + &t, outer_context)); if (!evaluated) { return errors::InvalidArgument( "Received a shape scalar with unknown static value. A static value " @@ -471,7 +477,9 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, // a float. Tensor t; bool evaluated = false; - if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t).ok()) { + if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t, + outer_context) + .ok()) { if (evaluated && target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) { return Status::OK(); @@ -481,7 +489,7 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, // Then try to infer partial shape from the input to the cast tensor. ShapeHandle pre_cast_shape; if (!ConstantPartialShape(target_context, input_edge->src(), 0, - &pre_cast_shape) + &pre_cast_shape, outer_context) .ok()) { TF_RETURN_IF_ERROR( target_context->MakeShapeFromTensor(nullptr, src_shape, result)); @@ -510,8 +518,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, for (int i = 0; i < src_context->num_inputs(); ++i) { int64 size; bool evaluated; - TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i, - &evaluated, &size)); + TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge( + input_edge->src(), i, &evaluated, &size, outer_context)); if (evaluated) { dims.push_back(size < 0 ? target_context->UnknownDim() : target_context->MakeDim(size)); @@ -531,7 +539,7 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, if (i == concat_dim) continue; ShapeHandle sub_result; TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(), - i, &sub_result)); + i, &sub_result, outer_context)); if (!target_context->RankKnown(sub_result)) { // Failed to evaluate. Treat the output as completely unknown. // TODO(cwhipkey): we could rely on all inputs being the same rank, so @@ -543,8 +551,8 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, target_context->Concatenate(*result, sub_result, result)); } } else if (src_op == "StridedSlice") { - TF_RETURN_IF_ERROR( - PartialStridedSliceShape(input_edge->src(), src_context, result)); + TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context, + result, outer_context)); } else if (src_op == "VariableShape") { auto* handle_data = src_context->input_handle_shapes_and_types(0); if (handle_data != nullptr && !handle_data->empty()) { @@ -555,17 +563,17 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, } else { Tensor t; bool evaluated = false; - TF_RETURN_IF_ERROR( - EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t)); + TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, + &t, outer_context)); TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor( evaluated ? &t : nullptr, src_shape, result)); } return Status::OK(); } -Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node, - InferenceContext* ctx, - ShapeHandle* result) { +Status ShapeRefiner::PartialStridedSliceShape( + Node* slice_node, InferenceContext* ctx, ShapeHandle* result, + shape_inference::InferenceContext* outer_context) { // Only attempt to evaluate if begin/end/strides all are scalars. for (int i = 1; i <= 3; ++i) { ShapeHandle input_shape = ctx->input(i); @@ -600,8 +608,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node, if (begin_mask == 1) { begin = 0; } else { - TF_RETURN_IF_ERROR( - EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin)); + TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, + &begin, outer_context)); if (!evaluated) { *result = ctx->UnknownShape(); return Status::OK(); @@ -612,8 +620,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node, if (end_mask == 1) { end = std::numeric_limits::max(); } else { - TF_RETURN_IF_ERROR( - EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end)); + TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, + &end, outer_context)); if (!evaluated) { *result = ctx->UnknownShape(); return Status::OK(); @@ -621,8 +629,8 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node, } int64 stride; - TF_RETURN_IF_ERROR( - EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride)); + TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, + &stride, outer_context)); if (!evaluated) { *result = ctx->UnknownShape(); return Status::OK(); @@ -630,14 +638,16 @@ Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node, // Apply stride to input interpreted as a partial shape. ShapeHandle input; - TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input)); + TF_RETURN_IF_ERROR( + ConstantPartialShape(ctx, slice_node, 0, &input, outer_context)); TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result)); return Status::OK(); } Status ShapeRefiner::RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, - ExtendedInferenceContext* ec) { + ExtendedInferenceContext* ec, + InferenceContext* outer_context) { // This will be filled in with real data in a second pass. std::vector input_tensors(node->num_inputs(), nullptr); std::vector real_tensors(node->num_inputs()); @@ -719,8 +729,8 @@ Status ShapeRefiner::RunShapeFn(const Node* node, Tensor result; bool evaluated = false; - TF_RETURN_IF_ERROR( - EvaluateConstantTensorForEdge(node, i, &evaluated, &result)); + TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge( + node, i, &evaluated, &result, outer_context)); if (evaluated) { real_tensors[i] = result; input_tensors[i] = &real_tensors[i]; @@ -736,7 +746,7 @@ Status ShapeRefiner::RunShapeFn(const Node* node, input_tensors_as_shapes.resize(i + 1); } ShapeHandle s; - TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s)); + TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s, outer_context)); input_tensors_as_shapes[i] = s; rerun_shape_fn = true; } diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index c83bd81705b..e298700a0b0 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -184,17 +184,56 @@ class ShapeRefiner { AttrSlice attributes, ExtendedInferenceContext* outer_context); + // Performs shape inference for a node inside a function. + // + // 'outer_context' is the 'InferenceContext' for the function's call op. + Status InferShapesForFunctionSubNode( + const Node* node, shape_inference::InferenceContext* outer_context); + + // Performs validation of 'node' and runs 'node's shape function, + // storing its shape outputs. + // + // All inputs of 'node' must be added to ShapeRefiner prior to + // adding 'node'. + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. + // + // Returns an error if: + // - the shape function for 'node' was not registered. + // - 'node' was added before its inputs. + // - The shape inference function returns an error. + Status AddNodeInternal(const Node* node, + shape_inference::InferenceContext* outer_context); + // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge // value can be evaluated, 'evaluated' is set to true and the value returned // in 'result'. Otherwise 'evaluated' is set to false. - Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx, - bool* evaluated, Tensor* result); + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. + Status EvaluateConstantTensorForEdge( + const Node* node, int dst_idx, bool* evaluated, Tensor* result, + shape_inference::InferenceContext* outer_context); // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input // tensors. The caller is responsible for checking that the specified edge is // scalar and int32 or int64. - Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx, - bool* evaluated, int64* result); + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. + Status EvaluateConstantIntScalarEdge( + const Node* node, int dst_idx, bool* evaluated, int64* result, + shape_inference::InferenceContext* outer_context); // This function tries to materialize as much information about the 'node''s // dst_idx input as a statically computable shape, and the result may be @@ -217,17 +256,39 @@ class ShapeRefiner { // // is used when creating new DimensionHandle and ShapeHandle // objects. + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. Status ConstantPartialShape(shape_inference::InferenceContext* target_context, const Node* node, int dst_idx, - shape_inference::ShapeHandle* result); + shape_inference::ShapeHandle* result, + shape_inference::InferenceContext* outer_context); // Implementation of ConstantPartialShape for StridedSlice nodes. - Status PartialStridedSliceShape(Node* slice_node, - shape_inference::InferenceContext* ctx, - shape_inference::ShapeHandle* result); + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. + Status PartialStridedSliceShape( + Node* slice_node, shape_inference::InferenceContext* ctx, + shape_inference::ShapeHandle* result, + shape_inference::InferenceContext* outer_context); + // Runs the shape function registered for the node's op type. + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, - ExtendedInferenceContext* ec); + ExtendedInferenceContext* ec, + shape_inference::InferenceContext* outer_context = nullptr); int32 graph_def_version_; const OpRegistryInterface* const ops_registry_; diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 44fa5bf2d3a..02cd53221d4 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/core/common_runtime/threadpool_device.h" +#include "absl/base/call_once.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/common_runtime/scoped_allocator.h" #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h" @@ -55,18 +55,14 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options, if (DisableMKL()) return; #ifdef _OPENMP const char* user_omp_threads = getenv("OMP_NUM_THREADS"); + static absl::once_flag omp_setting_flag; if (user_omp_threads == nullptr) { // OMP_NUM_THREADS controls MKL's intra-op parallelization // Default to available physical cores const int mkl_intra_op = port::NumSchedulableCPUs(); const int ht = port::NumHyperthreadsPerCore(); - omp_set_num_threads((mkl_intra_op + ht - 1) / ht); - } else { - uint64 user_val = 0; - if (strings::safe_strtou64(user_omp_threads, &user_val)) { - // Superflous but triggers OpenMP loading - omp_set_num_threads(user_val); - } + absl::call_once(omp_setting_flag, omp_set_num_threads, + (mkl_intra_op + ht - 1) / ht); } #endif // _OPENMP #endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL) diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index 2973548ab19..4e9d4bcba49 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -94,6 +94,23 @@ Status CreateWorkerStub(const std::string& address, const std::string& protocol, stub = WorkerService::NewStub(channel); return Status::OK(); } + +void PrepareGraph(GraphDef* graph) { + for (NodeDef& node : *graph->mutable_node()) { + for (const auto& op : kNodeNameSharingOps) { + // Set `use_node_name_sharing` to `true` so that resources aren't deleted + // prematurely. Otherwise, resources may be deleted when their ops are + // deleted at the end of the GraphRunner::Run used by standalone::Dataset. + if (node.op() == op) { + (*node.mutable_attr())["use_node_name_sharing"].set_b(true); + } + if (!node.device().empty()) { + *node.mutable_device() = ""; + } + } + } + StripDevicePlacement(graph->mutable_library()); +} } // namespace DataServiceDispatcherImpl::DataServiceDispatcherImpl( @@ -324,23 +341,16 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( TF_RETURN_IF_ERROR(CheckStarted()); uint64 fingerprint; DatasetDef dataset_def = request->dataset(); - TF_RETURN_IF_ERROR(HashGraph(dataset_def.graph(), &fingerprint)); - // Set `use_node_name_sharing` to `true` so that resources aren't deleted - // prematurely. Otherwise, resources may be deleted when their ops are - // deleted at the end of the GraphRunner::Run used by standalone::Dataset. - for (NodeDef& node : *dataset_def.mutable_graph()->mutable_node()) { - for (const auto& op : kNodeNameSharingOps) { - if (node.op() == op) { - (*node.mutable_attr())["use_node_name_sharing"].set_b(true); - } - } - } + GraphDef* graph = dataset_def.mutable_graph(); + PrepareGraph(graph); + TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint)); + mutex_lock l(mu_); #if defined(PLATFORM_GOOGLE) - VLOG_LINES(4, absl::StrCat("Registering dataset graph: ", - dataset_def.graph().DebugString())); + VLOG_LINES(4, + absl::StrCat("Registering dataset graph: ", graph->DebugString())); #else - VLOG(4) << "Registering dataset graph: " << dataset_def.graph().DebugString(); + VLOG(4) << "Registering dataset graph: " << graph->DebugString(); #endif std::shared_ptr dataset; Status s = state_.DatasetFromFingerprint(fingerprint, dataset); @@ -644,7 +654,14 @@ Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request, mutex_lock l(mu_); VLOG(3) << "Looking up tasks for job client id " << request->job_client_id(); std::shared_ptr job; - TF_RETURN_IF_ERROR(state_.JobForJobClientId(request->job_client_id(), job)); + Status s = state_.JobForJobClientId(request->job_client_id(), job); + if (errors::IsNotFound(s) && !config_.fault_tolerant_mode()) { + return errors::NotFound( + "Unknown job client id ", request->job_client_id(), + ". The dispatcher is not configured to be fault tolerant, so this " + "could be caused by a dispatcher restart."); + } + TF_RETURN_IF_ERROR(s); std::vector> tasks; TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks)); for (const auto& task : tasks) { diff --git a/tensorflow/core/data/service/grpc_util.cc b/tensorflow/core/data/service/grpc_util.cc index 73ea384ea60..0551d9537fc 100644 --- a/tensorflow/core/data/service/grpc_util.cc +++ b/tensorflow/core/data/service/grpc_util.cc @@ -31,7 +31,8 @@ Status WrapError(const std::string& message, const ::grpc::Status& status) { return errors::Internal("Expected a non-ok grpc status. Wrapping message: ", message); } else { - return Status(static_cast(status.error_code()), + Status s = FromGrpcStatus(status); + return Status(s.code(), absl::StrCat(message, ": ", status.error_message())); } } diff --git a/tensorflow/core/data/service/split_provider.cc b/tensorflow/core/data/service/split_provider.cc index b3100d52ff1..4ebb25348b6 100644 --- a/tensorflow/core/data/service/split_provider.cc +++ b/tensorflow/core/data/service/split_provider.cc @@ -22,10 +22,6 @@ limitations under the License. namespace tensorflow { namespace data { -namespace { -const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes. -} // namespace - Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) { mutex_lock l(mu_); if (!dispatcher_) { @@ -38,7 +34,8 @@ Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) { *end_of_splits); }, "get next split", - /*deadline_micros=*/Env::Default()->NowMicros() + kRetryTimeoutMicros); + /*deadline_micros=*/Env::Default()->NowMicros() + + (timeout_ms_ * EnvTime::kMillisToMicros)); } Status DataServiceSplitProvider::Reset() { diff --git a/tensorflow/core/data/service/split_provider.h b/tensorflow/core/data/service/split_provider.h index 110b9e26ec7..57091de9db1 100644 --- a/tensorflow/core/data/service/split_provider.h +++ b/tensorflow/core/data/service/split_provider.h @@ -28,8 +28,12 @@ namespace data { class DataServiceSplitProvider : public SplitProvider { public: DataServiceSplitProvider(const std::string& address, - const std::string& protocol, int64 job_id) - : address_(address), protocol_(protocol), job_id_(job_id) {} + const std::string& protocol, int64 job_id, + int64 timeout_ms) + : address_(address), + protocol_(protocol), + job_id_(job_id), + timeout_ms_(timeout_ms) {} Status GetNext(Tensor* split, bool* end_of_splits) override; Status Reset() override; @@ -42,6 +46,7 @@ class DataServiceSplitProvider : public SplitProvider { const std::string address_; const std::string protocol_; const int64 job_id_; + const int64 timeout_ms_; mutex mu_; int64 repetition_ = 0; diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index 364681ef549..4621e1e8a80 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -150,7 +150,7 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized( case DISTRIBUTED_EPOCH: { auto split_provider = absl::make_unique( config_.dispatcher_address(), config_.protocol(), - task.task_def.job_id()); + task.task_def.job_id(), config_.dispatcher_timeout_ms()); TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider), &task.iterator)); break; @@ -182,7 +182,7 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, "Worker has not yet registered with dispatcher."); } auto it = tasks_.find(request->task_id()); - if (it == tasks_.end()) { + if (it == tasks_.end() || it->second->finished) { response->set_end_of_sequence(true); return Status::OK(); } @@ -191,7 +191,7 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, TF_RETURN_IF_ERROR(task->iterator->GetNext(&outputs, &end_of_sequence)); if (end_of_sequence) { VLOG(3) << "Reached end_of_sequence for task " << request->task_id(); - tasks_.erase(request->task_id()); + task->finished = true; pending_completed_tasks_.insert(request->task_id()); task_completion_cv_.notify_one(); } diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h index 5f05275622b..16a0ba0cd93 100644 --- a/tensorflow/core/data/service/worker_impl.h +++ b/tensorflow/core/data/service/worker_impl.h @@ -60,6 +60,7 @@ class DataServiceWorkerImpl { TaskDef task_def; mutex mu; bool initialized TF_GUARDED_BY(mu) = false; + bool finished = false; // TODO(aaudibert): Have standalone::Iterator own a reference to // standalone::Dataset so that we don't need to store the dataset here. std::unique_ptr dataset; diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index d3030f4ca0b..12b957334aa 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -242,6 +242,7 @@ tf_cc_test( cc_library( name = "cancellable_call", + srcs = ["cancellable_call.cc"], hdrs = ["cancellable_call.h"], deps = [ ":call_options", @@ -531,6 +532,7 @@ cc_library( srcs = ["collective_rma_distributed.cc"], hdrs = ["collective_rma_distributed.h"], deps = [ + ":call_options", ":cancellable_call", ":request_id", ":worker_cache", diff --git a/tensorflow/core/distributed_runtime/cancellable_call.cc b/tensorflow/core/distributed_runtime/cancellable_call.cc new file mode 100644 index 00000000000..ed25c3a1947 --- /dev/null +++ b/tensorflow/core/distributed_runtime/cancellable_call.cc @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/distributed_runtime/cancellable_call.h" + +namespace tensorflow { + +void CancellableCall::Start(const StatusCallback& done) { + if (cancel_mgr_ == nullptr) { + IssueCall(done); + return; + } + CancellationToken token = cancel_mgr_->get_cancellation_token(); + const bool not_yet_cancelled = + cancel_mgr_->RegisterCallback(token, [this]() { Cancel(); }); + if (not_yet_cancelled) { + IssueCall([this, token, done](const Status& s) { + cancel_mgr_->DeregisterCallback(token); + done(s); + }); + } else { + done(errors::Cancelled("RPC Request was cancelled")); + } +} + +void CancellableCall::Cancel() { + { + mutex_lock l(mu_); + if (is_cancelled_) { + return; + } + is_cancelled_ = true; + } + opts_.StartCancel(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/cancellable_call.h b/tensorflow/core/distributed_runtime/cancellable_call.h index 3d82bef5c80..7311c8e3a44 100644 --- a/tensorflow/core/distributed_runtime/cancellable_call.h +++ b/tensorflow/core/distributed_runtime/cancellable_call.h @@ -29,7 +29,8 @@ class CancellableCall { public: CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker, WorkerCacheInterface* wc) - : cancel_mgr_(cancel_mgr), + : is_cancelled_(false), + cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc), wi_(wc_->GetOrCreateWorker(remote_worker_)) {} @@ -38,22 +39,17 @@ class CancellableCall { virtual void IssueCall(const StatusCallback& done) = 0; - void Start(const StatusCallback& done) { - CancellationToken token = cancel_mgr_->get_cancellation_token(); - const bool not_yet_cancelled = - cancel_mgr_->RegisterCallback(token, [this]() { opts_.StartCancel(); }); - if (not_yet_cancelled) { - IssueCall([this, token, done](const Status& s) { - cancel_mgr_->DeregisterCallback(token); - done(s); - }); - } else { - done(errors::Cancelled("RPC Request was cancelled")); - } - } + void Start(const StatusCallback& done); + + // Cancels the RPC if it's not cancelled yet. This must be called after + // Start(). This is normally used if there's a needed to cancel the RPC from a + // sideband. If appliable, pass a cancellation manager to the constructor + // instead of using this method. + void Cancel() TF_LOCKS_EXCLUDED(mu_); protected: - mutable mutex mu_; + mutex mu_; + bool is_cancelled_; CancellationManager* const cancel_mgr_; // Not owned const string remote_worker_; WorkerCacheInterface* const wc_; // Not owned diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index 238d29065d2..9466c8ef96b 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -295,19 +295,30 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed( CompleteGroupCall* call = new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr, group_leader_, worker_cache_); - call->Start([this, device, cp, call, done](const Status& s) { - if (s.ok()) { - Status status = UpdateGroupCache(call->resp_); - if (status.ok()) { - CompleteGroupLocal(device, cp, done); - } else { - done(status, nullptr); - } - } else { - done(s, nullptr); - } + CancellationToken abortion_token = + abortion_cancel_mgr_.get_cancellation_token(); + bool already_aborted = !abortion_cancel_mgr_.RegisterCallback( + abortion_token, [call] { call->Cancel(); }); + if (already_aborted) { + done(errors::Cancelled("collective ops already aborted"), nullptr); delete call; - }); + return; + } + call->Start( + [this, device, cp, call, abortion_token, done](const Status& s) { + abortion_cancel_mgr_.DeregisterCallback(abortion_token); + if (s.ok()) { + Status status = UpdateGroupCache(call->resp_); + if (status.ok()) { + CompleteGroupLocal(device, cp, done); + } else { + done(status, nullptr); + } + } else { + done(s, nullptr); + } + delete call; + }); return; } else { return CompleteGroupLocal(device, cp, done); @@ -373,7 +384,17 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed( CompleteInstanceCall* call = new CompleteInstanceCall( cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr, group_leader_, worker_cache_); - call->Start([this, device, gr, cp, call, done](Status s) { + CancellationToken abortion_token = + abortion_cancel_mgr_.get_cancellation_token(); + bool already_aborted = !abortion_cancel_mgr_.RegisterCallback( + abortion_token, [call] { call->Cancel(); }); + if (already_aborted) { + done(errors::Cancelled("collective ops already aborted")); + delete call; + return; + } + call->Start([this, device, gr, cp, call, abortion_token, done](Status s) { + abortion_cancel_mgr_.DeregisterCallback(abortion_token); if (s.ok()) { s = UpdateInstanceCache(gr, cp, call->resp_); } @@ -388,4 +409,19 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed( } } +void CollectiveParamResolverDistributed::StartAbort(const Status& s) { + { + mutex_lock l(status_mu_); + if (!status_.ok()) { + VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring " + "subsequent abortion with status: " + << s; + return; + } + status_ = s; + } + StartAbortLocal(s); + abortion_cancel_mgr_.StartCancel(); +} + } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h index 89f923a800b..97445fa6cfd 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_ #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/platform/status.h" @@ -47,6 +48,8 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { CancellationManager* cancel_mgr, const StatusCallback& done) override; + void StartAbort(const Status& s) override; + protected: // Returns the cached group iff there's an entry for this group_key in the // local group_table_; returns nullptr otherwise. @@ -87,6 +90,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { WorkerCacheInterface* worker_cache_; // Not owned const string group_leader_; + CancellationManager abortion_cancel_mgr_; }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc index f08f7a3275d..1c62b17fe54 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc @@ -53,7 +53,7 @@ class FakeWorker : public TestWorkerInterface { CollectiveParamResolverDistributed* cpres) : name_(name), device_mgr_(dev_mgr), param_resolver_(cpres) {} - void GetStatusAsync(const GetStatusRequest* request, + void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, GetStatusResponse* response, bool fail_fast, StatusCallback done) override { std::vector dev_attr; diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc index 1861262e9b1..29fcd82a4df 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc @@ -19,9 +19,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/cancellable_call.h" #include "tensorflow/core/distributed_runtime/request_id.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/transport_options.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" @@ -78,12 +80,12 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( const string& key, Device* to_device, DeviceContext* to_device_ctx, const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, const DeviceLocality& client_locality, int dev_to_dev_stream_index, - const StatusCallback& done) { + CancellationManager* cancellation_manager, const StatusCallback& done) { if (peer_is_local) { CollectiveRemoteAccessLocal::RecvFromPeer( peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx, to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index, - done); + cancellation_manager, done); return; } @@ -166,15 +168,27 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( recv_buf_callback(s); return; } - state->call.reset( - new RecvBufCall(step_id_, peer_device, peer_task, key, to_device, - to_device_ctx, to_alloc_attr, to_tensor, client_locality, - state->server_attributes, &cancel_mgr_, worker_cache_)); - state->call->Start(recv_buf_callback); + state->call.reset(new RecvBufCall( + step_id_, peer_device, peer_task, key, to_device, to_device_ctx, + to_alloc_attr, to_tensor, client_locality, state->server_attributes, + cancellation_manager, worker_cache_)); + CancellationToken abortion_token = + abortion_cancel_mgr_.get_cancellation_token(); + bool already_aborted = !abortion_cancel_mgr_.RegisterCallback( + abortion_token, [state] { state->call->Cancel(); }); + if (already_aborted) { + recv_buf_callback(errors::Cancelled("collective ops already aborted")); + } else { + state->call->Start([this, abortion_token, + done = std::move(recv_buf_callback)](const Status& s) { + abortion_cancel_mgr_.DeregisterCallback(abortion_token); + done(s); + }); + } } void CollectiveRemoteAccessDistributed::CheckPeerHealth( - const string& peer_task, const StatusCallback& done) { + const string& peer_task, int64 timeout_in_ms, const StatusCallback& done) { if (peer_task == task_name_) { // Fast path if the peer is the worker itself. done(Status::OK()); @@ -191,13 +205,16 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth( "valid form is /job:xxx/replica:0/task:N")); return; } + auto opts = new CallOptions(); + opts->SetTimeout(timeout_in_ms); auto req = new GetStatusRequest(); auto resp = new GetStatusResponse(); - // We're not using Cancellable call because GetStatusAsync doesn't support - // cancellation yet. + // Note that fail_fast is not always respected, so we set a timeout as well. + // We're not using CancellableCall since check health shouldn't need to be + // cancelled. wi->GetStatusAsync( - req, resp, /*fail_fast*/ true, - [this, req, resp, wi, peer_task, done](Status s) { + opts, req, resp, /*fail_fast*/ true, + [this, opts, req, resp, wi, peer_task, done](Status s) { std::vector cached_attrs; if (s.ok()) { s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs); @@ -222,6 +239,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth( // first collective. s = Status::OK(); } + delete opts; delete req; delete resp; worker_cache_->ReleaseWorker(peer_task, wi); @@ -231,7 +249,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth( void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) { CollectiveRemoteAccessLocal::StartAbort(s); - cancel_mgr_.StartCancel(); + abortion_cancel_mgr_.StartCancel(); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h index ed4d448afd9..e3e61e537f7 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_ #include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/unbounded_work_queue.h" @@ -42,9 +43,10 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, const DeviceLocality& client_locality, int dev_to_dev_stream_index, + CancellationManager* cancellation_manager, const StatusCallback& done) override; - void CheckPeerHealth(const string& peer_task, + void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms, const StatusCallback& done) override; void StartAbort(const Status& s) override; @@ -54,7 +56,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { // Ownership of `work_queue_` is shared between `this` and // `CollectiveExecutorMgr`. std::shared_ptr work_queue_; - CancellationManager cancel_mgr_; + CancellationManager abortion_cancel_mgr_; string task_name_; }; diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc index 454111eb1b6..74282beff1f 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc @@ -74,7 +74,7 @@ class FakeWorker : public TestWorkerInterface { // worker is supposed to have. BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; } - void GetStatusAsync(const GetStatusRequest* request, + void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, GetStatusResponse* response, bool fail_fast, StatusCallback done) override { if (is_failed_) { @@ -126,7 +126,8 @@ class FakeWorker : public TestWorkerInterface { } done(s); if (h) BufRendezvous::DoneWithHook(h); - }); + }, + nullptr /*cancellation_manager*/); } private: @@ -311,7 +312,8 @@ TEST_F(CollRMADistTest, ProdFirstOK) { [&producer_note, &producer_status](const Status& s) { producer_status.Update(s); producer_note.Notify(); - }); + }, + nullptr /*cancellation_manager*/); Device* dst_device = nullptr; string dev_name = "CPU:0"; TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device)); @@ -322,6 +324,7 @@ TEST_F(CollRMADistTest, ProdFirstOK) { false, // peer_is_local kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, + nullptr /*cancellation_manager*/, [&consumer_status, &consumer_note](const Status& s) { consumer_status = s; consumer_note.Notify(); @@ -351,6 +354,7 @@ TEST_F(CollRMADistTest, ConsFirstOK) { false, // peer_is_local kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, + nullptr /*cancellation_manager*/, [&consumer_status, &consumer_note](const Status& s) { consumer_status = s; consumer_note.Notify(); @@ -361,7 +365,8 @@ TEST_F(CollRMADistTest, ConsFirstOK) { [&producer_note, &producer_status](const Status& s) { producer_status.Update(s); producer_note.Notify(); - }); + }, + nullptr /*cancellation_manager*/); consumer_note.WaitForNotification(); TF_EXPECT_OK(consumer_status); producer_note.WaitForNotification(); @@ -384,6 +389,7 @@ TEST_F(CollRMADistTest, ConsFirstAbort) { false, // peer_is_local kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, + nullptr /*cancellation_manager*/, [&consumer_status, &consumer_note](const Status& s) { consumer_status = s; consumer_note.Notify(); @@ -411,6 +417,7 @@ TEST_F(CollRMADistTest, WorkerRestart) { false, // peer_is_local buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, + nullptr /*cancellation_manager*/, [&consumer_status, &consumer_note](const Status& s) { consumer_status = s; consumer_note.Notify(); @@ -421,7 +428,8 @@ TEST_F(CollRMADistTest, WorkerRestart) { [&producer_note, &producer_status](const Status& s) { producer_status.Update(s); producer_note.Notify(); - }); + }, + nullptr /*cancellation_manager*/); consumer_note.WaitForNotification(); TF_EXPECT_OK(consumer_status); producer_note.WaitForNotification(); @@ -437,6 +445,7 @@ TEST_F(CollRMADistTest, WorkerRestart) { false, // peer_is_local buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, + nullptr /*cancellation_manager*/, [&consumer_status, &post_restart_note](const Status& s) { consumer_status = s; post_restart_note.Notify(); @@ -450,7 +459,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) { Status check_health_status; Notification check_health_done; rma_->CheckPeerHealth( - "/job:worker/replica:0/task:1", + "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0, [&check_health_status, &check_health_done](const Status s) { check_health_status = s; check_health_done.Notify(); @@ -463,7 +472,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) { Status check_health_status; Notification check_health_done; rma_->CheckPeerHealth( - "/job:worker/replica:0/task:1", + "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0, [&check_health_status, &check_health_done](const Status s) { check_health_status = s; check_health_done.Notify(); @@ -479,7 +488,7 @@ TEST_F(CollRMADistTest, CheckHealthRestarted) { Status check_health_status; Notification check_health_done; rma_->CheckPeerHealth( - "/job:worker/replica:0/task:1", + "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0, [&check_health_status, &check_health_done](const Status s) { check_health_status = s; check_health_done.Notify(); @@ -496,7 +505,7 @@ TEST_F(CollRMADistTest, CheckHealthFailedPeer) { Status check_health_status; Notification check_health_done; rma_->CheckPeerHealth( - "/job:worker/replica:0/task:1", + "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0, [&check_health_status, &check_health_done](const Status s) { check_health_status = s; check_health_done.Notify(); @@ -511,7 +520,7 @@ TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) { Status check_health_status; Notification check_health_done; rma_->CheckPeerHealth( - "/job:worker/replica:0/task:1", + "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0, [&check_health_status, &check_health_done](const Status s) { check_health_status = s; check_health_done.Notify(); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 2138ecdfe95..ff44642c68e 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -752,7 +752,7 @@ tensorflow::Status EagerServiceImpl::GetServerContext( auto iter = contexts_.find(context_id); if (iter == contexts_.end()) { *server_context = nullptr; - return errors::InvalidArgument(strings::Printf( + return errors::Unavailable(strings::Printf( "Unable to find a context_id matching the specified one " "(%llu). Perhaps the worker was restarted, or the context was GC'd?", static_cast(context_id))); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 9d35ddf08f7..4a97be5c0c4 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -1248,7 +1248,7 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) { // Unable to handle the request since there is no eager context. Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request, &remote_enqueue_response); - EXPECT_EQ(error::INVALID_ARGUMENT, status.code()); + EXPECT_EQ(error::UNAVAILABLE, status.code()); EXPECT_TRUE(absl::StrContains( status.error_message(), "Unable to find a context_id matching the specified one")); @@ -1285,7 +1285,7 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) { Status status = eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response); - EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); + EXPECT_EQ(status.code(), error::UNAVAILABLE); EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id", status.error_message()); diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc index 05a9072894e..bb9b074858a 100644 --- a/tensorflow/core/distributed_runtime/remote_device.cc +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -143,7 +143,8 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, } } }; - wi->GetStatusAsync(&call->req, &call->resp, /*fail_fast=*/false, cb); + wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp, + /*fail_fast=*/false, cb); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 14a358e8ac2..97dc8257750 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -94,6 +94,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_cache_logger", "//tensorflow/core/distributed_runtime:worker_interface", diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index 07ab6c69d2e..51a48fe88ec 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -127,10 +127,10 @@ class GrpcRemoteMaster : public MasterInterface { ::grpc::Status (MasterServiceStub::*pfunc)( ::grpc::ClientContext*, const Request&, Response*), string trace_string = {}) { - int64 timeout_in_ms = call_options->GetTimeout(); - int64 expired_time_micros = Env::Default()->NowMicros(); - if (timeout_in_ms > 0) { - expired_time_micros += (timeout_in_ms * 1000); + absl::Duration timeout = absl::Milliseconds(call_options->GetTimeout()); + absl::Time expired_time = absl::FromUnixMicros(Env::Default()->NowMicros()); + if (timeout > absl::ZeroDuration()) { + expired_time += timeout; } Status s; for (int num_retries = 0;; ++num_retries) { @@ -140,7 +140,7 @@ class GrpcRemoteMaster : public MasterInterface { trace.reset(NewTraceRpc(trace_string, &ctx)); } ctx.set_fail_fast(false); - if (timeout_in_ms > 0) { + if (timeout > absl::ZeroDuration()) { // We do not modify the timeout here to match legacy behavior. However, // this could violate the contract of tensorflow::Session. If we retry // an RPC just before the deadline is exceeded, we will still set the @@ -148,8 +148,7 @@ class GrpcRemoteMaster : public MasterInterface { // being double what was expected. // TODO(b/117162170): investigate fixing this behavior for legacy and // gRPC RPC layers. - ctx.set_deadline(absl::ToChronoTime(absl::Now() + - absl::Milliseconds(timeout_in_ms))); + ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout)); } s = FromGrpcStatus((stub_.get()->*pfunc)(&ctx, *request, response)); if (!errors::IsUnavailable(s)) { @@ -164,20 +163,20 @@ class GrpcRemoteMaster : public MasterInterface { LOG(WARNING) << "Too many retries, returning last status: " << s; return s; } - const int64 now_micros = Env::Default()->NowMicros(); - const int64 deadline_with_backoff_micros = - now_micros + ComputeBackoffMicroseconds(num_retries); + absl::Time now = absl::FromUnixMicros(Env::Default()->NowMicros()); + const absl::Time deadline_with_backoff = + now + absl::Microseconds(ComputeBackoffMicroseconds(num_retries)); // Wait for a short period of time before retrying the RPC. If our // backoff would put us past the RPC deadline, we truncate it to ensure // our RPC starts before the deadline. - const auto backoff_until = - (timeout_in_ms <= 0 || - expired_time_micros > deadline_with_backoff_micros) - ? deadline_with_backoff_micros - : expired_time_micros; - Env::Default()->SleepForMicroseconds(backoff_until - now_micros); - const int64 now = Env::Default()->NowMicros(); - if (now > expired_time_micros && timeout_in_ms > 0) { + const auto backoff_until = (timeout <= absl::ZeroDuration() || + expired_time > deadline_with_backoff) + ? deadline_with_backoff + : expired_time; + Env::Default()->SleepForMicroseconds( + absl::ToInt64Microseconds(backoff_until - now)); + now = absl::FromUnixMicros(Env::Default()->NowMicros()); + if (now > expired_time && timeout > absl::ZeroDuration()) { // If timeout_in_ms is set, exit the retry loop on timeout. return errors::DeadlineExceeded(ctx.debug_error_string()); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index d529abef36c..986ae6adf78 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -20,6 +20,7 @@ limitations under the License. #include "grpcpp/generic/generic_stub.h" #include "grpcpp/grpcpp.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" @@ -70,10 +71,10 @@ class GrpcRemoteWorker : public WorkerInterface { ~GrpcRemoteWorker() override {} - void GetStatusAsync(const GetStatusRequest* request, + void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request, GetStatusResponse* response, bool fail_fast, StatusCallback done) override { - IssueRequest(request, response, getstatus_, std::move(done), nullptr, + IssueRequest(request, response, getstatus_, std::move(done), call_opts, fail_fast); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h index 041b6e51ffb..d0e67cdcd57 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h @@ -49,21 +49,43 @@ class RPCState : public GrpcClientCQTag { : RPCState( stub, cq, method, request, response, std::move(done), call_opts, threadpool, - // 1) If GRPC_FAIL_FAST is specified, fail_fast=$GRPC_FAIL_FAST. - // See b/141948186. - // 2) Otherwise, if the platform is Google, use the fail_fast from - // the caller. See b/140260119. - // 3) Otherwise, use fail_fast=false. - [fail_fast]() -> bool { - bool x; + // 1) If GRPC_FAIL_FAST is set to 'true' or 'false', + // fail_fast=$GRPC_FAIL_FAST. See b/141948186. + // 2) Otherwise if GRPC_FAIL_FAST is set to 'use_caller', use the + // fail_fast from the caller. See b/140260119. + // + // Current default for PLATFORM_GOOGLE: use caller fail_fast; + // Current default for open source: fail_fast=false. + // + // NOTE: Callers mostly set fail_fast=true to prevent job hanging + // on worker task failures, except a few cases such as GetStatus + // in cluster initialization and collective param resolution. + [fail_fast, &done]() -> bool { + string fail_fast_env; #if defined(PLATFORM_GOOGLE) - TF_CHECK_OK(ReadBoolFromEnvVar("GRPC_FAIL_FAST", fail_fast, &x)); + TF_CHECK_OK(ReadStringFromEnvVar("GRPC_FAIL_FAST", "use_caller", + &fail_fast_env)); #else - TF_CHECK_OK(ReadBoolFromEnvVar("GRPC_FAIL_FAST", false, &x)); + TF_CHECK_OK(ReadStringFromEnvVar("GRPC_FAIL_FAST", "false", + &fail_fast_env)); #endif // PLATFORM_GOOGLE - return x; + string fail_fast_env_lower = absl::AsciiStrToLower(fail_fast_env); + if (fail_fast_env_lower == "true") { + return true; + } else if (fail_fast_env_lower == "use_caller") { + return fail_fast; + } else if (fail_fast_env_lower == "false") { + return false; + } else { + string error_message = strings::StrCat( + "Invalid GRPC_FAIL_FAST config: ", fail_fast_env); + LOG(WARNING) << error_message; + done(errors::InvalidArgument(error_message)); + return false; + } }(), - /*timeout_in_ms=*/0, max_retries, target) { + (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries, + target) { } template diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index f833adcf932..723a5130161 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -696,7 +696,8 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, }; rma->buf_rendezvous()->ConsumeBuf( request->buf_rendezvous_key(), request->src_device(), - request->src_incarnation(), consumer_callback); + request->src_incarnation(), consumer_callback, + /*cancellation_manager=*/nullptr); } void GrpcWorker::LoggingAsync(const LoggingRequest* request, diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h index cec09775469..dc9badfedef 100644 --- a/tensorflow/core/distributed_runtime/test_utils.h +++ b/tensorflow/core/distributed_runtime/test_utils.h @@ -30,7 +30,7 @@ namespace tensorflow { // testing. class TestWorkerInterface : public WorkerInterface { public: - void GetStatusAsync(const GetStatusRequest* request, + void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, GetStatusResponse* response, bool fail_fast, StatusCallback done) override { done(errors::Unimplemented("GetStatusAsync")); diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index c4dc51ce47d..be14a58ca49 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -35,7 +35,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) { StatusGroup::ConfigureLogHistory(); } -void Worker::GetStatusAsync(const GetStatusRequest* request, +void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, GetStatusResponse* response, bool fail_fast, StatusCallback done) { const DeviceMgr* dm = env_->device_mgr; diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h index 273335ff36f..e280cf2447d 100644 --- a/tensorflow/core/distributed_runtime/worker.h +++ b/tensorflow/core/distributed_runtime/worker.h @@ -45,7 +45,7 @@ class Worker : public WorkerInterface { Worker(WorkerEnv* env); virtual ~Worker() {} - void GetStatusAsync(const GetStatusRequest* request, + void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, GetStatusResponse* response, bool fail_fast, StatusCallback done) override; diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index 9492d1cd31b..7b759eef95b 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -36,7 +36,8 @@ class TensorResponse; // Interface for talking with the TensorFlow Worker service. class WorkerInterface { public: - virtual void GetStatusAsync(const GetStatusRequest* request, + virtual void GetStatusAsync(CallOptions* opts, + const GetStatusRequest* request, GetStatusResponse* response, bool fail_fast, StatusCallback done) = 0; @@ -132,7 +133,7 @@ class WorkerInterface { GetStatusResponse* response) { Status ret; Notification n; - GetStatusAsync(request, response, /*fail_fast=*/true, + GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true, [&ret, &n](const Status& s) { ret = s; n.Notify(); diff --git a/tensorflow/core/example/BUILD b/tensorflow/core/example/BUILD index 19521c0ea82..752a3641b09 100644 --- a/tensorflow/core/example/BUILD +++ b/tensorflow/core/example/BUILD @@ -83,11 +83,9 @@ tf_cc_test( ], ) -filegroup( +alias( name = "example_parser_configuration_testdata", - srcs = [ - "testdata/parse_example_graph_def.pbtxt", - ], + actual = "//tensorflow/core/example/testdata:example_parser_configuration_testdata", ) tf_proto_library( diff --git a/tensorflow/core/example/testdata/BUILD b/tensorflow/core/example/testdata/BUILD new file mode 100644 index 00000000000..5b021e95762 --- /dev/null +++ b/tensorflow/core/example/testdata/BUILD @@ -0,0 +1,15 @@ +# Example parser test data. + +package( + default_visibility = [ + "//tensorflow/core/example:__pkg__", + ], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "example_parser_configuration_testdata", + srcs = [ + "parse_example_graph_def.pbtxt", + ], +) diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 215a4ddccbf..de196f20da9 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -10,6 +10,7 @@ load( "tf_cc_tests", "tf_copts", "tf_cuda_library", + "tf_gen_options_header", ) # buildifier: disable=same-origin-load @@ -155,6 +156,7 @@ exports_files( "op.h", "op_def_builder.h", "op_def_util.h", + "registration_options", "selective_registration.h", "shape_inference.h", ], @@ -216,6 +218,7 @@ filegroup( "reader_op_kernel.h", "register_types.h", "register_types_traits.h", + "registration_options.h", "rendezvous.h", "resource_handle.h", "resource_mgr.h", @@ -400,6 +403,7 @@ filegroup( "queue_interface.h", "reader_interface.h", "register_types_traits.h", + "registration_options.h", "rendezvous.cc", "rendezvous.h", "resource_mgr.cc", @@ -967,9 +971,21 @@ cc_library( ], ) +tf_gen_options_header( + name = "gen_registration_options", + build_settings = { + "//tensorflow:enable_registration_v2": "REGISTRATION_V2", + }, + output_header = "registration_options.h", + template = "registration_options.h.tpl", +) + cc_library( name = "selective_registration", - hdrs = ["selective_registration.h"], + hdrs = [ + "registration_options.h", + "selective_registration.h", + ], deps = tf_selective_registration_deps(), ) @@ -1185,7 +1201,6 @@ tf_cc_tests( "variant_op_registry_test.cc", "variant_test.cc", ], - create_named_test_suite = True, linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index e63efcc15bf..cd4c28e1d2f 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -143,8 +143,8 @@ struct CollectiveParams { int source_rank = -1; // broadcast only // Rank of this device in each subdivision permutation. std::vector subdiv_rank; - std::unique_ptr merge_op; // reduction only - std::unique_ptr final_op; // reduction only + OpKernel* merge_op = nullptr; // reduction only + OpKernel* final_op = nullptr; // reduction only string ToString() const; }; @@ -268,6 +268,7 @@ class CollectiveRemoteAccess { Tensor* to_tensor, const DeviceLocality& client_locality, int dev_to_dev_stream_index, + CancellationManager* cancellation_manager, const StatusCallback& done) = 0; virtual void PostToPeer(const string& peer_device, const string& peer_task, @@ -276,12 +277,13 @@ class CollectiveRemoteAccess { const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor, const DeviceLocality& client_locality, + CancellationManager* cancellation_manager, const StatusCallback& done) = 0; // Checks the health of a collective peer. It probes the peer to see if it is // alive. Note that if a peer has restarted, it's considered a different one, // so CheckPeerHealth fails. - virtual void CheckPeerHealth(const string& peer_task, + virtual void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms, const StatusCallback& done) = 0; virtual BufRendezvous* buf_rendezvous() = 0; diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index ba0c2b84a1a..6aa31909197 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -75,6 +75,11 @@ constexpr char kTFDataResourceTag[] = "tfdata"; class DatasetBase; class SerializationContext; +inline bool IsTFDataFunction(const FunctionDef& func) { + return (func.attr().contains(data::kTFDataFunction) && + func.attr().at(data::kTFDataFunction).b()); +} + // Interface for reading values from a key-value store. // Used for restoring iterator state. This class is thread safe. // Please see comment on IteratorStateWriter for guidance around using the diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 951052b794a..cc985284dac 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -311,6 +311,15 @@ class AsyncInterleaveMany : public Node { (*total_processing_times)[long_name()] = self_processing_time + inputs_processing_time; } + + double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) { + double result = 0; + auto* parameter = gtl::FindOrNull(parameters_, kParallelism); + if (parameter) { + result += (*parameter)->value * AverageBufferedElementSize(); + } + return result; + } }; class KnownRatio : public Node { @@ -593,6 +602,26 @@ class AsyncKnownRatio : public Node { self_processing_time + inputs_processing_time; } + double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) { + double result = 0; + auto* parameter = gtl::FindOrNull(parameters_, kBufferSize); + if (!parameter) { + parameter = gtl::FindOrNull(parameters_, kParallelism); + } + + if (parameter) { + if (ratio_ == 0) { + result += (*parameter)->value * AverageBufferedElementSize(); + } else { + // The estimation is currently not accurate for MapAndBatchDataset for + // the maximum buffer size does not match `num_parallel_calls` + // parameter. + result += (*parameter)->value * AverageBufferedElementSize() / ratio_; + } + } + return result; + } + private: const double ratio_; }; @@ -1067,11 +1096,34 @@ double Node::TotalProcessingTime( } double Node::AverageBufferedElementSize() const { - if (buffered_elements_ == 0) { - return 0; + DCHECK_GE(num_elements_, 0); + DCHECK_GE(buffered_elements_, 0); + if (num_elements_ <= 0) { + if (buffered_elements_ <= 0) { + // If there are no produced elements or buffered elements recorded, return + // 0. + return 0; + } + // If there are no produced elements but some buffered elements, return the + // average size of all buffered elements. + return static_cast(buffered_bytes_) / + static_cast(buffered_elements_); } - return static_cast(buffered_bytes_) / - static_cast(buffered_elements_); + + if (buffered_elements_ <= 0) { + // If there are no buffered elements but some produced elements, return the + // average size of all produced elements. + return static_cast(bytes_produced_) / + static_cast(num_elements_); + } + + // Otherwise, return the mean value of average size of all produced elements + // and average size of all buffered elements. + return (static_cast(bytes_produced_) / + static_cast(num_elements_) + + static_cast(buffered_bytes_) / + static_cast(buffered_elements_)) / + 2.0; } double Node::OutputTimeForInputs( @@ -1275,20 +1327,17 @@ void Node::TotalMaximumBufferedBytesHelper( return; } - double result = 0; - auto* parameter = gtl::FindOrNull(parameters_, kBufferSize); - if (!parameter) { - parameter = gtl::FindOrNull(parameters_, kParallelism); - } - if (parameter) { - result = (*parameter)->value * AverageBufferedElementSize(); - } + double result = MaximumBufferedBytes(); for (auto& input : inputs_) { result += total_bytes->at(input->long_name()); } total_bytes->insert(std::make_pair(long_name(), result)); } +double Node::MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) { + return 0; +} + void Model::AddNode(Node::Factory factory, const string& name, std::shared_ptr parent, std::shared_ptr* out_node) { diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index 5199f2cbeef..a3cd0c06a48 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -251,6 +251,12 @@ class Node { // Returns the node output. Node* output() const { return output_; } + // Returns the parameter value. + double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return parameters_.at(name)->state->value; + } + // Returns the aggregate processing time. int64 processing_time() const TF_LOCKS_EXCLUDED(mu_) { return processing_time_; @@ -517,6 +523,12 @@ class Node { absl::flat_hash_map* total_bytes) const TF_SHARED_LOCKS_REQUIRED(mu_); + // Compute and return the maximum buffered bytes on the node itself. By + // default non-tunable nodes are assumed not to buffer any bytes, so the + // tunable nodes as subclasses are expected to override this method to ensure + // that the optimization algorithm respects the memory budget. + virtual double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_); + // Stores the time passed to the last call to `Node::record_start()` on the // current thread. // diff --git a/tensorflow/core/framework/model_test.cc b/tensorflow/core/framework/model_test.cc index bdfd2c4df2d..97eb720b058 100644 --- a/tensorflow/core/framework/model_test.cc +++ b/tensorflow/core/framework/model_test.cc @@ -131,7 +131,9 @@ TEST_P(AsyncKnownRatioTest, Model) { async_known_many->record_buffer_event(110, 10); EXPECT_EQ(async_known_many->TotalBufferedBytes(), 110); EXPECT_EQ(async_known_many->TotalMaximumBufferedBytes(), - 110 * parallelism / 10); + num_inputs_per_output == 0 + ? 110.0 * parallelism / 10 + : 110.0 * parallelism / 10 / num_inputs_per_output); source1->add_processing_time(100); EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr), 0); @@ -385,41 +387,12 @@ TEST(UnknownTest, Model) { EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 100); } -class TestNode : public model::Node { - public: - using model::Node::Node; - - virtual ~TestNode() {} - - protected: - std::shared_ptr Clone(std::shared_ptr output) const override - TF_SHARED_LOCKS_REQUIRED(mu_) { - return nullptr; - } - - void InputTimeLocked(absl::flat_hash_map* input_times) - const override TF_SHARED_LOCKS_REQUIRED(mu_) {} - - void OutputTimeLocked( - const absl::flat_hash_map& input_times, - absl::flat_hash_map* gradients, - absl::flat_hash_map* output_times, - absl::flat_hash_map* output_time_gradients) const override - TF_SHARED_LOCKS_REQUIRED(mu_) { - (*output_times)[long_name()] = 0; - } - - void TotalProcessingTimeLocked( - absl::flat_hash_map* processing_times, - absl::flat_hash_map* total_processing_times) override - TF_SHARED_LOCKS_REQUIRED(mu_) { - (*total_processing_times)[long_name()] = 0; - } -}; - -TEST(SetterGetterTest, Node) { - std::shared_ptr node = - std::make_shared(model::Node::Args{-1, "TestNode", nullptr}); +TEST(BufferedBytesTest, Node) { + std::shared_ptr node = model::MakeAsyncInterleaveManyNode( + {-1, "TestNode", nullptr}, + {model::MakeParameter("parallelism", + std::make_shared(3, nullptr, nullptr), + 1, 7)}); EXPECT_EQ(node->id(), -1); EXPECT_EQ(node->name(), "TestNode"); EXPECT_EQ(node->output(), nullptr); @@ -428,16 +401,46 @@ TEST(SetterGetterTest, Node) { EXPECT_EQ(node->buffered_elements(), 0); EXPECT_EQ(node->TotalBufferedBytes(), 0); EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0); - node->record_buffer_event(42, 0); - EXPECT_EQ(node->buffered_bytes(), 42); - EXPECT_EQ(node->TotalBufferedBytes(), 0); - EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0); - EXPECT_EQ(node->buffered_elements(), 0); - node->record_buffer_event(0, 11); - EXPECT_EQ(node->buffered_bytes(), 42); - EXPECT_EQ(node->TotalBufferedBytes(), 0); - EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0); - EXPECT_EQ(node->buffered_elements(), 11); + + node->record_buffer_event(20, 1); + EXPECT_EQ(node->buffered_bytes(), 20); + EXPECT_EQ(node->buffered_elements(), 1); + EXPECT_EQ(node->TotalBufferedBytes(), 20); + EXPECT_EQ(node->TotalMaximumBufferedBytes(), 60); + + node->record_buffer_event(10, 1); + EXPECT_EQ(node->buffered_bytes(), 30); + EXPECT_EQ(node->buffered_elements(), 2); + EXPECT_EQ(node->TotalBufferedBytes(), 30); + EXPECT_EQ(node->TotalMaximumBufferedBytes(), 45); + + node->record_buffer_event(18, 1); + EXPECT_EQ(node->buffered_bytes(), 48); + EXPECT_EQ(node->buffered_elements(), 3); + EXPECT_EQ(node->bytes_produced(), 0); + EXPECT_EQ(node->num_elements(), 0); + EXPECT_EQ(node->TotalBufferedBytes(), 48); + EXPECT_EQ(node->TotalMaximumBufferedBytes(), 48); + + node->record_buffer_event(-20, -1); + node->record_element(); + node->record_bytes_produced(20); + EXPECT_EQ(node->buffered_bytes(), 28); + EXPECT_EQ(node->buffered_elements(), 2); + EXPECT_EQ(node->bytes_produced(), 20); + EXPECT_EQ(node->num_elements(), 1); + EXPECT_EQ(node->TotalBufferedBytes(), 28); + EXPECT_EQ(node->TotalMaximumBufferedBytes(), 51); + + node->record_buffer_event(-10, -1); + node->record_element(); + node->record_bytes_produced(10); + EXPECT_EQ(node->buffered_bytes(), 18); + EXPECT_EQ(node->buffered_elements(), 1); + EXPECT_EQ(node->bytes_produced(), 30); + EXPECT_EQ(node->num_elements(), 2); + EXPECT_EQ(node->TotalBufferedBytes(), 18); + EXPECT_EQ(node->TotalMaximumBufferedBytes(), 49.5); EXPECT_EQ(node->processing_time(), 0); node->record_start(1); @@ -447,22 +450,32 @@ TEST(SetterGetterTest, Node) { node->add_processing_time(2); EXPECT_EQ(node->processing_time(), 42); - std::shared_ptr input = - std::make_shared(model::Node::Args{-1, "TestInput", node}); + std::shared_ptr input = model::MakeAsyncKnownRatioNode( + {0, "TestInput", node}, 2, + {model::MakeParameter("parallelism", + std::make_shared(5, nullptr, nullptr), + 0, 6)}); EXPECT_EQ(input->output(), node.get()); EXPECT_EQ(node->inputs().size(), 0); node->add_input(input); EXPECT_EQ(node->inputs().size(), 1); EXPECT_EQ(node->inputs().front(), input); - input->record_buffer_event(13, 0); - EXPECT_EQ(node->TotalBufferedBytes(), 0); - EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0); + + input->record_buffer_event(28, 1); + EXPECT_EQ(node->bytes_consumed(), 0); + EXPECT_EQ(node->TotalBufferedBytes(), 46); + EXPECT_EQ(node->TotalMaximumBufferedBytes(), 119.5); + + input->record_buffer_event(-28, -1); + input->record_element(); + input->record_bytes_produced(28); + node->record_bytes_consumed(28); + EXPECT_EQ(node->bytes_consumed(), 28); + EXPECT_EQ(node->TotalBufferedBytes(), 18); + EXPECT_EQ(node->TotalMaximumBufferedBytes(), 119.5); + node->remove_input(input); EXPECT_EQ(node->inputs().size(), 0); - - EXPECT_EQ(node->num_elements(), 0); - node->record_element(); - EXPECT_EQ(node->num_elements(), 1); } // Returns a weighted sum of a prior and the actual processing time. @@ -879,6 +892,63 @@ TEST_P(SelfProcessingTimeTest, Model) { INSTANTIATE_TEST_SUITE_P(Test, SelfProcessingTimeTest, ::testing::Values(0, 1, 2, 5, 10, 20, 40)); +class OptimizeZeroRamBudgetTest + : public ::testing::TestWithParam {}; + +TEST_P(OptimizeZeroRamBudgetTest, Model) { + const model::AutotuneAlgorithm algorithm = GetParam(); + + std::shared_ptr mutex1 = std::make_shared(); + std::shared_ptr cv1 = + std::make_shared(); + std::shared_ptr node1 = model::MakeAsyncKnownRatioNode( + {1, "1", nullptr}, 2, + {model::MakeParameter("parallelism", + std::make_shared(-1, mutex1, cv1), 1, + 5)}); + node1->record_buffer_event(1, 1); + + std::shared_ptr mutex2 = std::make_shared(); + std::shared_ptr cv2 = + std::make_shared(); + std::shared_ptr node2 = model::MakeAsyncKnownRatioNode( + {2, "2", node1}, 5, + {model::MakeParameter("buffer_size", + std::make_shared(-1, mutex2, cv2), 0, + 6)}); + node2->record_buffer_event(1, 1); + + std::shared_ptr mutex3 = std::make_shared(); + std::shared_ptr cv3 = + std::make_shared(); + std::shared_ptr node3 = model::MakeAsyncInterleaveManyNode( + {3, "3", node2}, + {model::MakeParameter("parallelism", + std::make_shared(-1, mutex3, cv3), 1, + 7)}); + node3->record_buffer_event(1, 1); + + EXPECT_EQ(node1->parameter_value("parallelism"), -1); + EXPECT_EQ(node2->parameter_value("buffer_size"), -1); + EXPECT_EQ(node3->parameter_value("parallelism"), -1); + + model::Model model; + model.AddNode([&node1](model::Node::Args args) { return node1; }, "1", + nullptr, &node1); + model.AddNode([&node2](model::Node::Args args) { return node2; }, "2", node1, + &node2); + model.AddNode([&node3](model::Node::Args args) { return node3; }, "3", node2, + &node3); + + model.Optimize(algorithm, 40, 0, 0); + EXPECT_EQ(node1->parameter_value("parallelism"), 1); + EXPECT_EQ(node2->parameter_value("buffer_size"), 0); + EXPECT_EQ(node3->parameter_value("parallelism"), 1); +} + +INSTANTIATE_TEST_SUITE_P(Test, OptimizeZeroRamBudgetTest, + ::testing::Values(0, 1)); + } // namespace } // namespace model } // namespace data diff --git a/tensorflow/core/framework/registration_options.h.tpl b/tensorflow/core/framework/registration_options.h.tpl new file mode 100644 index 00000000000..375a1088b51 --- /dev/null +++ b/tensorflow/core/framework/registration_options.h.tpl @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTRATION_OPTIONS_TMPL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_REGISTRATION_OPTIONS_TMPL_H_ + +// This header is generated from a template; see the tf_gen_options_header() +// build rule. Template placeholders of the form '#define_option X' result in +// macros of the form 'TF_OPTION_X()'. + +#define_option REGISTRATION_V2 + +#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTRATION_OPTIONS_TMPL_H_ diff --git a/tensorflow/core/framework/selective_registration.h b/tensorflow/core/framework/selective_registration.h index c9bbcb8bfe8..06ea0e00004 100644 --- a/tensorflow/core/framework/selective_registration.h +++ b/tensorflow/core/framework/selective_registration.h @@ -35,6 +35,10 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/registration_options.h" + +#if !TF_OPTION_REGISTRATION_V2() + #ifdef SELECTIVE_REGISTRATION // Experimental selective registration support to reduce binary size. @@ -66,12 +70,20 @@ limitations under the License. !defined(SHOULD_REGISTER_OP_KERNEL)) static_assert(false, "ops_to_register.h must define SHOULD_REGISTER macros"); #endif -#else +#else // SELECTIVE_REGISTRATION #define SHOULD_REGISTER_OP(op) true #define SHOULD_REGISTER_OP_GRADIENT true #define SHOULD_REGISTER_OP_KERNEL(clz) true +#endif // SELECTIVE_REGISTRATION + +#else // ! TF_OPTION_REGISTRATION_V2() + +#ifdef SELECTIVE_REGISTRATION +#error TF_OPTION_REGISTRATION_V2(): Compile-time selective registration is not supported #endif +#endif // ! TF_OPTION_REGISTRATION_V2() + namespace tensorflow { // An InitOnStartupMarker is 'initialized' on program startup, purely for the diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 456c1826572..2d81b294372 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -719,7 +719,7 @@ Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape( ShapeHandle input_shape; TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape)); - requested_input_tensor_as_partial_shape_[input_idx] = true; + request_input_tensor_as_partial_shape(input_idx); const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size(); if (input_idx < input_tensors_as_shapes_size && input_tensors_as_shapes_[input_idx].IsSet() && @@ -738,7 +738,7 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx, ShapeHandle input_shape; TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape)); - requested_input_tensor_as_partial_shape_[input_idx] = true; + request_input_tensor_as_partial_shape(input_idx); const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size(); if (input_idx < input_tensors_as_shapes_size && input_tensors_as_shapes_[input_idx].IsSet() && diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 10b54476d18..a7c72ebe294 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -268,15 +268,31 @@ class InferenceContext { // not available at the time of shape inference. const Tensor* input_tensor(int idx) { // Mark that this idx was requested. - requested_input_tensor_[idx] = true; + request_input_tensor(idx); return input_tensors_[idx]; } + // Notifies the shape refiner that the value of the tensor at index + // is needed. The shape refiner tries to statically compute this tensor, + // and if successful re-runs the shape function with this tensor available + // in the call to 'input_tensor(idx)'. + void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; } + // Returns true iff input_tensor(idx) was called by the shape function. bool requested_input_tensor(int idx) const { return requested_input_tensor_[idx]; } + // Notifies the shape refiner that the value of the tensor at index + // as a partial shape is needed. The shape refiner tries to statically compute + // this, and if successful re-runs the shape function with the + // computed PartialTensorShape available in the call to + // 'MakeShapeFromShapeTensor(idx, handle)' or + // 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'. + void request_input_tensor_as_partial_shape(int idx) { + requested_input_tensor_as_partial_shape_[idx] = true; + } + // Returns true if MakeShapeFromInputTensor was called but the constant // input_tensor was not present. bool requested_input_tensor_as_partial_shape(int idx) const { diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index bf57e263441..7680bcacba5 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -371,6 +371,13 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef, void OptimizeControlFlowColocation(Graph* graph) { auto visit = [](Node* node) { if (IsSwitch(node)) { + // Pivot Switch nodes (which are also of type Switch) are already placed + // on the CPU and colocated with its inputs that are also already on the + // CPU (or might be placed on GPU but in host memory). + if (HasNodeAttr(node->def(), "_PivotSwitch")) { + DCHECK(node->requested_device().find("CPU") != string::npos); + return; + } for (const Edge* in_edge : node->in_edges()) { if (in_edge->dst_input() == 0) { // Colocate with the data input. diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h index 9d9727bdb46..efa1cfb0d3c 100644 --- a/tensorflow/core/graph/mkl_graph_util.h +++ b/tensorflow/core/graph/mkl_graph_util.h @@ -102,6 +102,16 @@ bool inline NativeFormatEnabled() { return native_fmt_enabled; } +// Check if the data_format attribute in the node def represents 5D tensor +bool inline Check5DFormat(const NodeDef& ndef) { + string data_format; + TF_CHECK_OK(GetNodeAttr(ndef, "data_format", &data_format)); + if (data_format.compare("NCDHW") == 0 || data_format.compare("NDHWC") == 0) { + return true; + } + return false; +} + namespace mkl_op_registry { // MKL operators whose kernels are registered with 'MklLayoutDependentOp' label // (e.g., MklConv2D) understand input tensors in MKL layout. These operators diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 7fc74b0aca5..a1af69354e4 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -717,6 +717,7 @@ tf_cuda_cc_test( ":custom_graph_optimizer_registry", ":meta_optimizer", "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -869,6 +870,7 @@ tf_kernel_library( deps = [ ":constant_folding", ":graph_optimizer", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc index 4d324ecbd3d..1288f9695b9 100644 --- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc +++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc @@ -45,6 +45,9 @@ constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3"; constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset"; constexpr char kRebatchDatasetOpName[] = "RebatchDataset"; constexpr char kRebatchDatasetV2OpName[] = "RebatchDatasetV2"; +constexpr char kTensorDatasetOpName[] = "TensorDataset"; +constexpr char kTensorSliceDatasetOpName[] = "TensorSliceDataset"; +constexpr char kPlaceholderOpName[] = "Placeholder"; constexpr char kNumWorkersAttrName[] = "num_workers"; constexpr char kNumReplicasAttrName[] = "num_replicas"; @@ -68,12 +71,13 @@ constexpr std::array kMultipleInputsDatasetOps = { "ZipDataset" }; -constexpr std::array kPassThroughOps = { +constexpr std::array kPassThroughOps = { "_Retval", "AssertNextDataset", "BatchDataset", "CacheDataset", "ExperimentalMapAndBatchDataset", + "ExperimentalParseExampleDataset", "ExperimentalRebatchDataset", "FilterDataset", "Identity", @@ -413,6 +417,33 @@ Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node, return Status::OK(); } +const NodeDef* FindFuncAndTensorSliceDataset( + const NodeDef* node, int64 num_workers, int64 index, + FunctionLibraryDefinition* flib, MutableGraphView* graph, + absl::flat_hash_set* nodes_to_delete) { + if (IsDatasetNodeOfType(*node, kFuncDatasetOps)) { + const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0); + if (input_node->op() == kTensorSliceDatasetOpName || + input_node->op() == kTensorDatasetOpName) { + const NodeDef* next_input_node = + graph_utils::GetInputNode(*input_node, *graph, 0); + if (next_input_node->op() == kPlaceholderOpName) { + return node; + } + } + } + + if (!IsDatasetNodeOfType(*node, kPassThroughOps)) { + return nullptr; + } + + // Sometimes there are other nodes between the last InterleaveDataset and the + // second to last FlatMapDataset, so we need to skip over those. + const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0); + return FindFuncAndTensorSliceDataset(input_node, num_workers, index, flib, + graph, nodes_to_delete); +} + Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index, FunctionLibraryDefinition* flib, MutableGraphView* graph, @@ -441,6 +472,39 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index, return Status::OK(); } + // This handles the case for the following subgraph: + // Placeholder -> TensorSliceDataset -> FlatMapDataset -x-> + // (other preprocessing datasets) -> InterleaveDataset + // and then inserting the shard node immediately after the FlatMapDataset. + // + // This is used for some training pipelines where a dataset is created with + // the following code: + // + // def make_dataset_pipeline(): + // file_globs = [...] + // datasets = [] + // for file_glob in file_globs: + // datasets.append(Dataset.list_files(file_glob).map(TFRecordReader)) + // dataset = Dataset.from_tensor_slices(datasets) + // dataset = dataset.flat_map(lambda x: x) + // dataset = ... # additional preprocessing + // dataset = dataset.interleave(lambda x: x, cycle_length=...) + // return dataset + if (IsDatasetNodeOfType(node, kFuncDatasetOps)) { + const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0); + const NodeDef* flat_map_node = FindFuncAndTensorSliceDataset( + input_node, num_workers, index, flib, graph, nodes_to_delete); + + if (flat_map_node != nullptr) { + auto fanouts = graph->GetFanouts(*flat_map_node, false); + // FlatMapDataset should only be the input to one other dataset. + if (fanouts.size() == 1) { + return ProcessDatasetSourceNode(graph, *fanouts.begin()->node, + nodes_to_delete, num_workers, index); + } + } + } + // This handles the case where a reader Dataset is contained within a // FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example: // @@ -570,7 +634,6 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index, MutableGraphView graph(output); FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library()); - NodeDef* sink_node; TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node)); diff --git a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc index 3a336f87f0a..f8b493f06d4 100644 --- a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc +++ b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc @@ -29,7 +29,6 @@ namespace tensorflow { namespace grappler { namespace { -constexpr char kRetValOp[] = "_Retval"; constexpr char kMaxIntraOpParallelismDataset[] = "MaxIntraOpParallelismDataset"; constexpr char kModelDataset[] = "ModelDataset"; @@ -46,17 +45,11 @@ Status DisableIntraOpParallelism::OptimizeAndCollectStats( *output = item.graph; MutableGraphView graph(output); - for (const auto& fetch_name : item.fetch) { - // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, - // because we only want to disable intra op parallelism on the main dataset - // pipeline. - auto fetch = graph.GetNode(fetch_name); - if (fetch == nullptr || fetch->op() == kRetValOp) { - // Heuristic: If the fetch nodes are Retval ops, this item is from a - // function. - return Status::OK(); - } - } + // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, + // because we only want to disable intra op parallelism on the main dataset + // pipeline. + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return Status::OK(); if (item.fetch.size() != 1) { return errors::InvalidArgument( diff --git a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc index 4ece16542c8..4159f518e27 100644 --- a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc +++ b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc @@ -31,7 +31,6 @@ namespace { constexpr char kAlgorithm[] = "algorithm"; constexpr char kModelDataset[] = "ModelDataset"; -constexpr char kRetValOp[] = "_Retval"; constexpr int64 HILL_CLIMB = 0; constexpr int64 GRADIENT_DESCENT = 1; @@ -49,17 +48,11 @@ Status EnableGradientDescent::OptimizeAndCollectStats( } MutableGraphView graph(output); - for (const auto& fetch_name : item.fetch) { - // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, - // because we only want to enable gradient descent on the main dataset - // pipeline. - auto fetch = graph.GetNode(fetch_name); - if (fetch == nullptr || fetch->op() == kRetValOp) { - // Heuristic: If the fetch nodes are Retval ops, this item is from a - // function. - return Status::OK(); - } - } + // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, + // because we only want to enable gradient descent on the main dataset + // pipeline. + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return Status::OK(); int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output); NodeDef& model_node = *(output->mutable_node(index)); diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc index d70a1ca486e..e09ea575ce4 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc @@ -430,9 +430,8 @@ FunctionDef* FuseFunctions( const SetInputFn& set_input, const SetOutputFn& set_output, const SetNodesFn& set_nodes, FunctionDefLibrary* library) { auto has_attrs = [](const FunctionDef& func) { - return !( - func.attr_size() == 0 || - (func.attr_size() == 1 && func.attr().contains(data::kTFDataFunction))); + return !(func.attr_size() == 0 || + (func.attr_size() == 1 && data::IsTFDataFunction(func))); }; if (has_attrs(first_function) || has_attrs(second_function)) { return nullptr; // Functions with attributes are currently not supported. diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 8bc33ea8464..10207de920c 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -27,6 +27,7 @@ namespace graph_utils { namespace { constexpr char kConstOpName[] = "Const"; +constexpr char kRetValOp[] = "_Retval"; template std::vector GetElementIndicesWithPredicate(const Predicate& predicate, @@ -367,6 +368,19 @@ Status GetFetchNode(const MutableGraphView& graph, const GrapplerItem& item, return Status::OK(); } +bool IsItemDerivedFromFunctionDef(const GrapplerItem& item, + const MutableGraphView& graph_view) { + for (const auto& fetch_name : item.fetch) { + auto fetch = graph_view.GetNode(fetch_name); + if (fetch != nullptr && fetch->op() != kRetValOp) { + // We found a fetch node which is not a `Retval` op. + return false; + } + } + // All fetch nodes are `Retval` ops (or we don't have any fetch nodes). + return true; +} + } // namespace graph_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 87c9831126f..3a397e50106 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -169,6 +169,13 @@ Status EnsureNodeNamesUnique(Graph* g); Status GetFetchNode(const MutableGraphView& graph, const GrapplerItem& item, NodeDef** fetch_node); +// Returns true if `item` is derived from a `FunctionDef`, false otherwise. +// Currently, we determine this heuristically: If we don't have any fetch nodes +// or all fetch nodes are `Retval` ops, then we consider this item as derived +// from a `FunctionDef`. +bool IsItemDerivedFromFunctionDef(const GrapplerItem& item, + const MutableGraphView& graph_view); + } // namespace graph_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc index cd46a7356ac..b4317577cb8 100644 --- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc @@ -118,7 +118,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, for (const auto& name : flib.ListFunctionNames()) { auto* func = flib.Find(name); // Skip non tf.data functions. - if (!func->attr().contains(data::kTFDataFunction)) continue; + if (!data::IsTFDataFunction(*func)) continue; VLOG(3) << "Optimize function: function=" << func->signature().name(); optimized_functions = true; diff --git a/tensorflow/core/grappler/optimizers/data/slack.cc b/tensorflow/core/grappler/optimizers/data/slack.cc index 211b53ba083..fad2b9f7f67 100644 --- a/tensorflow/core/grappler/optimizers/data/slack.cc +++ b/tensorflow/core/grappler/optimizers/data/slack.cc @@ -33,7 +33,6 @@ namespace grappler { namespace { -constexpr char kRetValOp[] = "_Retval"; constexpr char kPrefetchDatasetOp[] = "PrefetchDataset"; template @@ -116,17 +115,13 @@ Status Slack::OptimizeAndCollectStats(Cluster* cluster, *output = item.graph; MutableGraphView graph(output); - for (const auto& fetch_name : item.fetch) { - // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, - // because we only want to add slack to the prefetch on the main dataset - // pipeline. - auto fetch = graph.GetNode(fetch_name); - if (fetch == nullptr || fetch->op() == kRetValOp) { - // Heuristic: If the fetch nodes are Retval ops, this item is from a - // function. - return Status::OK(); - } - } + + // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, + // because we only want to add slack to the prefetch on the main dataset + // pipeline. + if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) + return Status::OK(); + if (item.fetch.size() != 1) { return errors::InvalidArgument( "Expected only one fetch node but there were ", item.fetch.size(), ": ", diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc index ef68b7e7898..3bf465d057f 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc @@ -67,6 +67,8 @@ constexpr char kOpConst[] = "Const"; constexpr char kReshape[] = "Reshape"; constexpr char kReshapeConst[] = "ReshapeConst"; constexpr int kRank = 4; +constexpr int kUnknownRank = -1; +constexpr int kInvalidRank = -2; inline bool AttrDataFormatMatch(const utils::MutableNodeView& node, absl::string_view src_data_format, @@ -554,15 +556,23 @@ Status Transposer::UpdateEdge( return Status::OK(); } -bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port, - int n) const { +int Transposer::GetFanoutPortRank(const utils::MutableNodeView& node, + int port) const { const auto* output_shape_attr = node.GetAttr(kAttrOutputShape); if (output_shape_attr == nullptr || output_shape_attr->list().shape_size() <= port) { - return false; + return kInvalidRank; } const auto& shape = output_shape_attr->list().shape(port); - return !shape.unknown_rank() && shape.dim_size() == n; + if (shape.unknown_rank()) { + return kUnknownRank; + } + return shape.dim_size(); +} + +bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port, + int n) const { + return GetFanoutPortRank(node, port) == n; } bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node, @@ -575,14 +585,18 @@ bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node, return true; } -bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port, - int n) const { +int Transposer::GetFaninPortRank(const utils::MutableNodeView& node, + int port) const { if (port < node.NumRegularFanins() && port >= 0) { const auto& regular_fanin = node.GetRegularFanin(port); - return IsFanoutPortRankN(*regular_fanin.node_view(), regular_fanin.index(), - n); + return GetFanoutPortRank(*regular_fanin.node_view(), regular_fanin.index()); } - return false; + return kInvalidRank; +} + +bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port, + int n) const { + return GetFaninPortRank(node, port) == n; } bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node, @@ -719,11 +733,12 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context, Status DefaultLayoutSensitiveOpTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsDefaultLayoutSensitiveOp(*node->node())); - const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); - const auto& shape = output_shape_attr->list().shape(0); - const int rank = shape.dim_size(); + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } ScopedDataFormatUpgrader data_format_upgrader(context, rank); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) { + if (!ShouldProcess(*context, *node)) { return Status::OK(); } VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() @@ -904,12 +919,12 @@ bool FusedBatchNormGradTransposer::IsTraining( Status FusedBatchNormGradTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsFusedBatchNormGrad(*node->node())); - const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); - const auto& shape = output_shape_attr->list().shape(0); - const int rank = shape.dim_size(); + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } ScopedDataFormatUpgrader data_format_upgrader(context, rank); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) || - !IsTraining(*node)) { + if (!ShouldProcess(*context, *node) || !IsTraining(*node)) { return Status::OK(); } VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() @@ -1089,9 +1104,7 @@ std::vector LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts( Status DefaultLayoutAgnosticOpTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsDefaultLayoutAgnosticOp(*node->node())); - const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); - const auto& shape = output_shape_attr->list().shape(0); - const int rank = shape.dim_size(); + const int rank = GetFanoutPortRank(*node, 0); if (rank != 4 && rank != 5) { return Status::OK(); } @@ -1249,9 +1262,10 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context, Status BinaryOpTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsBinaryOp(*node->node())); - const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); - const auto& shape = output_shape_attr->list().shape(0); - const int rank = shape.dim_size(); + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } ScopedDataFormatUpgrader data_format_upgrader(context, rank); if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) || !IsAfterDstToSrcTransform(*context, *node)) { @@ -1432,13 +1446,12 @@ bool ReduceTransposer::IsReduceAxisSupported( Status ReduceTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsReduceOp(*node->node())); - const auto& regular_fanin = node->GetRegularFanin(0); - const auto* output_shape_attr = - regular_fanin.node_view()->GetAttr(kAttrOutputShape); - const auto& shape = output_shape_attr->list().shape(0); - const int rank = shape.dim_size(); + const int rank = GetFaninPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } ScopedDataFormatUpgrader data_format_upgrader(context, rank); - if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, rank) || + if (!ShouldProcess(*context, *node) || !IsReduceAxisSupported(*context, *node) || !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h index 11a223ee097..bfc67c0633d 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h @@ -149,10 +149,12 @@ class Transposer { utils::MutationNewNode* added_node); protected: + int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const; bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port, int n) const; bool IsFanoutPortsRankN(const utils::MutableNodeView& node, absl::Span ports, int n) const; + int GetFaninPortRank(const utils::MutableNodeView& node, int port) const; bool IsFaninPortRankN(const utils::MutableNodeView& node, int port, int n) const; diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 8f18dfdeef4..ddfa5522e01 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -93,10 +93,6 @@ bool IsRunOnceOptimizer(const string& name) { name == "auto_mixed_precision_mkl"; } -bool IsTFDataFunction(const FunctionDef& func) { - return func.attr().contains(data::kTFDataFunction); -} - // Creates a function library stub from a real function library: copy only // signatures and attributes of all the function defined in fdef_lib. This stub // can be swapped with real function library in a graph, before passing it to @@ -615,6 +611,63 @@ Status MetaOptimizer::RunOptimizer( return Status::OK(); } +// Propagates `_tf_data_function` attributes from functions to their callees. +void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib, + FunctionDefLibrary& fdef_lib) { + // Collect functions that need the attribute in this set. + absl::flat_hash_set tf_data_functions; + std::function collect_tf_data_functions_dfs = + [&](const string& func_name) -> void { + // Return if we already found and added this function. + if (tf_data_functions.contains(func_name)) return; + + // We only get here if the function is (directly or indirectly) called from + // a tf.data function, so add it to the set. + tf_data_functions.insert(func_name); + + const FunctionDef* func_def = flib.Find(func_name); + // Skip functions that are not reachable from the optimized graph. + if (func_def == nullptr) return; + + // Proceed with DFS for functions called from current function. + for (const NodeDef& node : func_def->node_def()) { + if (flib.Contains(node.op())) { + // This is a function call node. + collect_tf_data_functions_dfs(node.op()); + } + // Check if there are functions in attributes. + for (const auto& attr : node.attr()) { + const AttrValue& attr_value = attr.second; + if (attr_value.has_func()) { + collect_tf_data_functions_dfs(attr_value.func().name()); + } + if (attr_value.has_list()) { + for (const auto& func : attr_value.list().func()) { + collect_tf_data_functions_dfs(func.name()); + } + } + } + } + }; + // Perform DFS for all tf.data functions in `fdef_lib`. + for (const auto& func_def : fdef_lib.function()) { + const string& func_name = func_def.signature().name(); + if (data::IsTFDataFunction(func_def)) + collect_tf_data_functions_dfs(func_name); + } + // Set attribute for tf.data functions. We cannot do this in the DFS directly + // because `FunctionLibraryDefinition` does not seem to provide mutable access + // to a `FunctionDef`. + for (FunctionDef& func_def : *fdef_lib.mutable_function()) { + const string& func_name = func_def.signature().name(); + if (tf_data_functions.contains(func_name) && + !data::IsTFDataFunction(func_def)) { + VLOG(2) << "Marking " << func_name << " as tf.data function"; + (*func_def.mutable_attr())[data::kTFDataFunction].set_b(true); + } + } +} + Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, GraphDef* optimized_graph) { const uint64 start_us = Env::Default()->NowMicros(); @@ -636,13 +689,13 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, // remove all the unreachable functions. // TODO(ezhulenev): Construct reachable function library definition directly // from the proto without constructing temporary FunctionLibraryDefinition. + int old_library_size = item.graph.library().function_size(); *item.graph.mutable_library() = minimized_flib(item.graph).ToProto(); + int new_library_size = item.graph.library().function_size(); VLOG(1) << absl::Substitute( "Deleted $0 unreachable functions from the graph (library size = $1)", - item.graph.library().function_size() - - item.graph.library().function_size(), - item.graph.library().function_size()); + old_library_size - new_library_size, new_library_size); // Save a few small fields from item before we move it. bool optimize_function_library = @@ -722,6 +775,8 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, for (const FunctionDef& function : optimized_graph->library().function()) { find_xla_compiled_functions(function.node_def()); } + // Propagate `_tf_data_function` attributes from functions to their callees. + PropagateTFDataAttrs(flib, *optimized_graph->mutable_library()); // Optimize each function only once. absl::flat_hash_set optimized_funcs; @@ -747,8 +802,9 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, // the function optimizer, before we can optimize function body. if (IsParametrized(func)) continue; - // Skip tf.data functions as they are optimized by tf.data meta optimizer. - if (IsTFDataFunction(func)) continue; + // Skip tf.data functions as they are optimized by tf.data meta optimizer + // and in function instantiation. + if (data::IsTFDataFunction(func)) continue; VLOG(3) << "Optimize function: function=" << func_name << " [" << function_idx++ << " of " diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index 595b636c7a9..85f7f911635 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/substitute.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -1016,6 +1017,144 @@ TEST_F(MetaOptimizerTest, CompressConstants) { } } +// Tests for checking expected behavior when skipping tf.data functions in +// meta optimizer. + +// Custom optimizer which counts its calls. +class TfDataTestOptimizer : public CustomGraphOptimizer { + public: + static void InitCount() { cnt_ = 0; } + static int GetCount() { return cnt_; } + + TfDataTestOptimizer() {} + string name() const override { return "tf_data_test_optimizer"; } + bool UsesFunctionLibrary() const override { return false; } + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override { + ++cnt_; + *optimized_graph = item.graph; + return Status::OK(); + } + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} + + private: + static int cnt_; +}; + +int TfDataTestOptimizer::cnt_; + +REGISTER_GRAPH_OPTIMIZER(TfDataTestOptimizer); + +// Test fixture for parametrized testing. +class TfDataTestFixture + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + is_my_mul_tf_data_ = std::get<0>(GetParam()); + is_my_square_tf_data_ = std::get<1>(GetParam()); + } + void RunTest(); + + private: + // controls which of the functions is flagged as tf.data function + bool is_my_mul_tf_data_ = false; + bool is_my_square_tf_data_ = false; +}; + +TEST_P(TfDataTestFixture, TfDataTests) { RunTest(); } + +// Core test function. +void TfDataTestFixture::RunTest() { + using test::function::NDef; + + // Define function library: + // + // MyMul(x, y) = x * y + // MySquare(x) = MyMul(x, x) + + FunctionDef mul_func = FunctionDefHelper::Create( + "MyMul", {"x:float", "y:float"}, {"z:float"}, {}, + {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}}, + /*ret_def=*/ + {{"z", "mul:z:0"}}); + (*mul_func.mutable_attr())[data::kTFDataFunction].set_b(is_my_mul_tf_data_); + + FunctionDef square_func = FunctionDefHelper::Create( + "MySquare", {"x:float"}, {"z:float"}, {}, + {{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", DT_FLOAT}}}}, + /*ret_def=*/ + {{"z", "my_mul:z:0"}}); + (*square_func.mutable_attr())[data::kTFDataFunction].set_b( + is_my_square_tf_data_); + + // Tensorflow graph: + // + // a = tf.Placeholder(tf.float); + // square = MySquare(a); // a^2 + GrapplerItem item; + item.id = "tf_graph"; + item.graph = test::function::GDef( + {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + // Calls into function library + NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice), + // Forward outputs + NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice)}, + /*funcs=*/ + {mul_func, square_func}); + + // Use only custom optimizer which counts its calls. + TfDataTestOptimizer::InitCount(); + ConfigProto config_proto; + auto& rewriter_config = + *(config_proto.mutable_graph_options()->mutable_rewrite_options()); + rewriter_config.add_optimizers("TfDataTestOptimizer"); + rewriter_config.set_min_graph_nodes(-1); + rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE); + + MetaOptimizer optimizer(nullptr, config_proto); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // We expect one graph optimization + one optimization for each non-tf.data + // function. Note that if `MySquare` is flagged as a tf.data function, then + // `MyMul` is implicitly also considered a tf.data function because it is + // called from `MySquare`. + int expected_count = 3; + if (is_my_square_tf_data_) + expected_count -= 2; + else if (is_my_mul_tf_data_) + expected_count -= 1; + EXPECT_EQ(TfDataTestOptimizer::GetCount(), expected_count); + + // We expect that the tf.data-attribute has been propagated from `MySquare` + // to its callee `MyMul` if the value is `true`. Otherwise, the attribute + // values should be unchanged. + FunctionLibraryDefinition flib(OpRegistry::Global(), output.library()); + const FunctionDef* square_func_after_opt = flib.Find("MySquare"); + const FunctionDef* mul_func_after_opt = flib.Find("MyMul"); + + EXPECT_EQ(data::IsTFDataFunction(*square_func_after_opt), + is_my_square_tf_data_); + if (is_my_square_tf_data_ || is_my_mul_tf_data_) { + EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), true); + } else { + EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), false); + } +} + +INSTANTIATE_TEST_SUITE_P(MetaOptimizerTest, TfDataTestFixture, + ::testing::Combine(::testing::Bool(), + ::testing::Bool())); + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index e8dd1a68bb3..b9bd6430991 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/env_var.h" +#include "tensorflow/core/util/util.h" #if GOOGLE_CUDA #include "third_party/gpus/cudnn/cudnn.h" @@ -306,6 +307,9 @@ bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) { return IsCpuCompatibleConv2D(&node); } else if (IsDepthwiseConv2dNative(node)) { #ifdef INTEL_MKL + if (DisableMKL()) { + return false; + } return IsCpuCompatibleDepthwiseConv2dNative(&node); #else return false; @@ -660,6 +664,7 @@ bool IsAddWithNoBroadcast(const RemapperContext& ctx, const NodeDef& node) { bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx, const utils::MutableNodeView& node_view, ContractionWithBiasAddAndAdd* matched) { + if (DisableMKL()) return false; // Fusion with AddN is supported only when it has two inputs. // TODO(lyandy): Forward controls for patterns with control dependencies. if (HasControlFaninOrFanout(node_view) || node_view.NumRegularFanins() != 2) @@ -710,6 +715,7 @@ bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx, bool FindContractionWithBiasAndAddActivation( const RemapperContext& ctx, int node_index, ContractionWithBiasAndAddActivation* matched) { + if (DisableMKL()) return false; const auto* node_view = ctx.graph_view.GetNode(node_index); // TODO(lyandy): Forward controls for patterns with control dependencies. if (HasControlFaninOrFanout(*node_view)) return false; @@ -838,6 +844,8 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index, #ifndef ENABLE_MKLDNN_V1 // We fuse FusedBatchNorm on GPU or MKL CPU. if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false; +#else + if (DisableMKL()) return false; #endif DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T"); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 6c045152434..edf2567a4d9 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -172,7 +172,6 @@ tf_kernel_library( "strided_slice_op_gpu_number_types.cu.cc", ], deps = [ - ":bounds_check", ":dense_update_functor", ":inplace_ops", ":ops_util", @@ -181,6 +180,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ], ) @@ -279,9 +279,9 @@ tf_kernel_library( "gpu_device_array_gpu.h", ], deps = [ - ":bounds_check", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ], alwayslink = 0, @@ -334,6 +334,7 @@ cc_library( hdrs = ["conv_2d.h"], deps = [ ":eigen_helpers", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ], ) @@ -464,8 +465,8 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", + "//tensorflow/core/framework:tensor_testutil", ], ) @@ -655,9 +656,9 @@ cc_library( hdrs = ["save_restore_tensor.h"], copts = if_not_windows(["-Wno-sign-compare"]), deps = [ - ":bounds_check", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/util/tensor_bundle", ], ) @@ -713,19 +714,8 @@ cc_library( ], ) -alias( - name = "bounds_check", - actual = "//tensorflow/core/framework:bounds_check", - visibility = [":friends"], -) - # Private support libraries --------------------------------------------------- -cc_header_only_library( - name = "bounds_check_lib", - deps = [":bounds_check"], -) - cc_library( name = "gpu_device_array", hdrs = [ @@ -910,7 +900,6 @@ cc_library( # OpKernel libraries ---------------------------------------------------------- ARRAY_DEPS = [ - ":bounds_check", ":concat_lib", ":fill_functor", ":gather_functor", @@ -922,6 +911,7 @@ ARRAY_DEPS = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ] @@ -1419,10 +1409,22 @@ tf_cc_test( ], ) +cc_library( + name = "ragged_tensor_variant", + srcs = ["ragged_tensor_variant.cc"], + hdrs = ["ragged_tensor_variant.h"], + deps = [ + ":cwise_op", + "//tensorflow/core:framework", + ], +) + tf_kernel_library( name = "ragged_tensor_to_variant_op", srcs = ["ragged_tensor_to_variant_op.cc"], deps = [ + ":concat_lib", + ":ragged_tensor_variant", "//tensorflow/core:framework", "//tensorflow/core:lib", ], @@ -1432,6 +1434,7 @@ tf_kernel_library( name = "ragged_tensor_from_variant_op", srcs = ["ragged_tensor_from_variant_op.cc"], deps = [ + ":ragged_tensor_variant", "//tensorflow/core:framework", "//tensorflow/core:lib", ], @@ -1444,6 +1447,7 @@ tf_cc_test( deps = [ ":ops_testutil", ":ragged_tensor_to_variant_op", + ":ragged_tensor_variant", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1460,6 +1464,7 @@ tf_cc_test( deps = [ ":ops_testutil", ":ragged_tensor_from_variant_op", + ":ragged_tensor_variant", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1491,11 +1496,11 @@ tf_kernel_library( srcs = ["cudnn_rnn_ops.cc"], visibility = ["//visibility:public"], deps = [ - ":bounds_check_lib", ":gpu_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/platform:stream_executor", "//third_party/eigen3", ], @@ -1652,6 +1657,9 @@ tf_cuda_cc_test( name = "conv_ops_test", size = "medium", srcs = ["conv_ops_test.cc"], + tags = [ + "no_cuda_asan", # TODO(b/171342275): re-enable. + ], deps = [ ":conv_ops", ":ops_testutil", @@ -1751,7 +1759,9 @@ tf_cuda_cc_test( name = "depthwise_conv_ops_test", size = "small", srcs = ["depthwise_conv_ops_test.cc"], - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + [ + "no_cuda_asan", # TODO(b/171342266): re-enable. + ], deps = [ ":conv_ops", ":ops_testutil", @@ -1885,8 +1895,8 @@ tf_kernel_library( prefix = "gather_functor", visibility = [":friends"], deps = [ - ":bounds_check", "//tensorflow/core:framework", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ], ) @@ -2115,10 +2125,10 @@ tf_kernel_library( prefix = "scatter_functor", visibility = [":friends"], deps = [ - ":bounds_check", ":dense_update_functor", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ], ) @@ -2230,9 +2240,9 @@ tf_cc_test( deps = [ ":transpose_functor", "//tensorflow/core:framework", - "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/framework:tensor_testutil", ], ) @@ -2283,7 +2293,7 @@ tf_kernel_library( name = "ctc_ops", prefix = "ctc", deps = [ - ":bounds_check", + "//tensorflow/core/framework:bounds_check", ":ops_util", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -2352,7 +2362,6 @@ cc_header_only_library( ) DATA_FLOW_DEPS = [ - ":bounds_check", ":concat_lib", ":conditional_accumulator", ":conditional_accumulator_base", @@ -2373,6 +2382,7 @@ DATA_FLOW_DEPS = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:bounds_check", ] tf_kernel_library( @@ -2496,7 +2506,7 @@ tf_kernel_library( ) DYNAMIC_DEPS = [ - ":bounds_check", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -2549,7 +2559,6 @@ tf_cc_test( ) LOOKUP_DEPS = [ - ":bounds_check", ":initializable_lookup_table", ":lookup_util", "@com_google_absl//absl/container:flat_hash_map", @@ -2557,6 +2566,7 @@ LOOKUP_DEPS = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:bounds_check", ] tf_kernel_library( @@ -2714,7 +2724,6 @@ tf_kernel_library( srcs = ["resource_variable_ops.cc"], hdrs = ["resource_variable_ops.h"], deps = [ - ":bounds_check", ":dense_update_functor", ":gather_functor", ":gather_nd_op", @@ -2725,6 +2734,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:bounds_check", "@com_google_absl//absl/strings", ], ) @@ -3006,12 +3016,12 @@ tf_kernel_library( ) SAVE_RESTORE_DEPS = [ - ":bounds_check_lib", ":save_restore_tensor", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/util/tensor_bundle", ] @@ -3174,9 +3184,9 @@ tf_kernel_library( "roll_op.h", ], deps = [ - ":bounds_check", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ], ) @@ -3201,13 +3211,13 @@ tf_cc_test( ) MATH_DEPS = [ - ":bounds_check", ":fill_functor", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:math_grad", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ] @@ -3332,6 +3342,8 @@ tf_kernel_library( prefix = "batch_matmul_op", deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([ "//third_party/mkl:intel_binary_blob", + ]) + if_cuda_or_rocm([ + "//tensorflow/core/kernels:gpu_utils", ]), ) @@ -3389,7 +3401,7 @@ tf_kernel_library( name = "fft_ops", prefix = "fft_ops", deps = MATH_DEPS + [ - ] + if_cuda([ + ] + if_cuda_or_rocm([":gpu_utils"]) + if_cuda([ "//tensorflow/core/platform/default/build_config:cufft_plugin", ]), ) @@ -3547,7 +3559,6 @@ tf_cuda_cc_test( size = "small", srcs = ["cwise_ops_test.cc"], deps = [ - ":bounds_check", ":cwise_op", ":nn", ":ops_testutil", @@ -3559,6 +3570,7 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/framework:bounds_check", ], ) @@ -3825,7 +3837,6 @@ tf_kernel_library( deps = [ ":conv_grad_shape_utils", ":conv_ops_3d_headers", - ":bounds_check", ":conv_2d", ":conv_3d", ":eigen_contraction_kernel", @@ -3839,6 +3850,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/util:image_resizer_state", "//tensorflow/core/util/proto:proto_utils", ] + select({ @@ -3892,12 +3904,12 @@ tf_kernel_library( "depthwise_conv_op_gpu_half.cu.cc", ], deps = [ - ":bounds_check", ":conv_ops", ":ops_util", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/framework:bounds_check", ] + if_cuda([ "@local_config_cuda//cuda:cub_headers", "@local_config_cuda//cuda:cudnn_header", @@ -3913,13 +3925,13 @@ tf_kernel_library( ], prefix = "depthwise_conv_grad_op", deps = [ - ":bounds_check", + ":cast_op", ":conv_ops", ":ops_util", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", - ":cast_op", + "//tensorflow/core/framework:bounds_check", ] + if_cuda([ "@local_config_cuda//cuda:cudnn_header", ]), @@ -3958,7 +3970,6 @@ cc_library( ) NN_DEPS = [ - ":bounds_check", ":conv_2d", ":eigen_contraction_kernel", ":ops_util", @@ -3966,6 +3977,7 @@ NN_DEPS = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:nn_grad", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ] @@ -4215,7 +4227,6 @@ tf_kernel_library( "pooling_ops_3d_gpu.cu.cc", ], deps = [ - ":bounds_check", ":conv_2d", ":conv_3d", ":conv_ops", @@ -4225,6 +4236,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/platform:stream_executor", "//third_party/eigen3", ], @@ -4286,9 +4298,9 @@ tf_kernel_library( ], visibility = [":friends"], deps = [ - ":bounds_check", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ], ) @@ -4479,7 +4491,6 @@ tf_kernel_library( name = "stateful_random_ops", prefix = "stateful_random_ops", deps = [ - ":bounds_check", ":dense_update_functor", ":fill_functor", ":gather_functor", @@ -4496,6 +4507,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:bounds_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", ], @@ -4505,11 +4517,11 @@ tf_kernel_library( name = "stateless_random_ops", prefix = "stateless_random_ops", deps = [ - ":bounds_check", ":random_op", ":random_poisson_op", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/framework:bounds_check", ], ) @@ -4698,9 +4710,9 @@ tf_kernel_library( name = "sparse_tensor_dense_matmul_op", prefix = "sparse_tensor_dense_matmul_op", deps = SPARSE_DEPS + [ - ":bounds_check", ":fill_functor", "//third_party/eigen3", + "//tensorflow/core/framework:bounds_check", ], ) @@ -4716,8 +4728,8 @@ tf_kernel_library( name = "sparse_xent_op", prefix = "sparse_xent_op", deps = SPARSE_DEPS + [ - ":bounds_check", "//third_party/eigen3", + "//tensorflow/core/framework:bounds_check", ] + if_cuda_or_rocm([ ":reduction_ops", ]) + if_cuda([ @@ -4871,7 +4883,7 @@ cc_library( STATE_DEPS = [ ":assign_op", - ":bounds_check", + "//tensorflow/core/framework:bounds_check", ":fill_functor", ":scatter_functor", "//third_party/eigen3", @@ -5042,7 +5054,7 @@ cc_library( ) STRING_DEPS = [ - ":bounds_check", + "//tensorflow/core/framework:bounds_check", ":string_util", "//third_party/eigen3", "//tensorflow/core:framework", @@ -5259,11 +5271,11 @@ tf_kernel_library( name = "unicode_ops", prefix = "unicode_ops", deps = [ - ":bounds_check", ":string_util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", "//third_party/icu/data:conversion_data", "@icu//:common", @@ -5280,11 +5292,11 @@ tf_kernel_library( name = "training_ops", prefix = "training_ops", deps = [ - ":bounds_check", ":training_op_helpers", ":variable_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ], ) @@ -5443,11 +5455,11 @@ tf_kernel_library( name = "encode_wav_op", prefix = "encode_wav_op", deps = [ - ":bounds_check", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/framework:bounds_check", ], ) @@ -6640,7 +6652,7 @@ tf_cc_binary( "//tensorflow/cc:cc_ops", "//tensorflow/cc:client_session", "//tensorflow/core:framework", - "//tensorflow/core:tensor_testutil", + "//tensorflow/core/framework:tensor_testutil", ], }), ) @@ -6697,7 +6709,7 @@ cc_binary( ":quantized_ops", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensor_testutil", + "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core:tensorflow", "//tensorflow/core:test", ], @@ -6902,7 +6914,7 @@ cc_binary( ":ops_util", ":quantized_ops", "//tensorflow/core:framework", - "//tensorflow/core:tensor_testutil", + "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", ], @@ -7042,7 +7054,7 @@ cc_binary( ], "//conditions:default": [ "//tensorflow/core:framework", - "//tensorflow/core:tensor_testutil", + "//tensorflow/core/framework:tensor_testutil", ], }), ) @@ -7362,11 +7374,11 @@ cc_library( "fill_functor.h", ], deps = [ - ":bounds_check", ":meta_support", ":quantization_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", "@gemmlowp", ], diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h index d6cc980633f..5c1e0cbe6e4 100644 --- a/tensorflow/core/kernels/batch_matmul_op_impl.h +++ b/tensorflow/core/kernels/batch_matmul_op_impl.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/matmul_autotune.h" #include "tensorflow/core/util/matmul_bcast.h" #include "tensorflow/core/util/work_sharder.h" @@ -43,8 +44,13 @@ limitations under the License. #endif #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" // For CUDA_VERSION +#endif namespace tensorflow { @@ -219,7 +225,8 @@ template struct LaunchBatchMatMul { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, Tensor* out) { + bool trans_y, const MatMulBCast& bcast, bool use_autotune, + Tensor* out) { typedef ParallelMatMulKernel::IsComplex> ParallelMatMulKernel; bool conjugate_result = false; @@ -275,45 +282,212 @@ se::DeviceMemory AsDeviceMemory(const T* gpu_memory) { return typed; } -class BlasScratchAllocator : public se::ScratchAllocator { +using BlasScratchAllocator = GpuScratchAllocator; + +int64 GetBlasWorkspaceLimit(const string& envvar_in_mb, + int64 default_value_in_bytes) { + return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes); +} + +// Encapsulate all of the shape, dtype etc. information that defines a unique +// batched matmul operation. +class BatchMatmulParameters { public: - using Stream = se::Stream; - using DeviceMemoryBytes = se::DeviceMemory; + BatchMatmulParameters(bool trans_a, bool trans_b, bool adj_a, bool adj_b, + uint64 m, uint64 n, uint64 k, uint64 batch_count, + bool broadcast_a, bool broadcast_b, DataType dtype_ab, + DataType dtype_cd, bool allow_tf32, int device_id) + : trans_a_(trans_a), + trans_b_(trans_b), + adj_a_(adj_a), + adj_b_(adj_b), + m_(m), + n_(n), + k_(k), + batch_count_(batch_count), + broadcast_a_(broadcast_a), + broadcast_b_(broadcast_b), + dtype_ab_(dtype_ab), + dtype_cd_(dtype_cd), + allow_tf32_(allow_tf32), + device_id_(device_id) { + hash_code_ = trans_a; + hash_code_ = Hash64Combine(hash_code_, trans_b); + hash_code_ = Hash64Combine(hash_code_, adj_a); + hash_code_ = Hash64Combine(hash_code_, adj_b); + hash_code_ = Hash64Combine(hash_code_, m); + hash_code_ = Hash64Combine(hash_code_, n); + hash_code_ = Hash64Combine(hash_code_, k); + hash_code_ = Hash64Combine(hash_code_, batch_count); + hash_code_ = Hash64Combine(hash_code_, broadcast_a); + hash_code_ = Hash64Combine(hash_code_, broadcast_b); + hash_code_ = Hash64Combine(hash_code_, dtype_ab); + hash_code_ = Hash64Combine(hash_code_, dtype_cd); + hash_code_ = Hash64Combine(hash_code_, allow_tf32); + hash_code_ = Hash64Combine(hash_code_, device_id); + } + bool operator==(const BatchMatmulParameters& other) const { + return this->get_data_as_tuple() == other.get_data_as_tuple(); + } - BlasScratchAllocator(OpKernelContext* context) : context_(context) {} + bool operator!=(const BatchMatmulParameters& other) const { + return !(*this == other); + } + uint64 hash() const { return hash_code_; } - int64 GetMemoryLimitInBytes() override { return -1; } - - se::port::StatusOr AllocateBytes( - int64 byte_size) override { - Tensor temporary_memory; - - Status allocation_status(context_->allocate_temp( - DT_UINT8, TensorShape({byte_size}), &temporary_memory)); - if (!allocation_status.ok()) { - return se::port::StatusOr( - DeviceMemoryBytes::MakeFromByteSize(nullptr, 0)); - } - // Hold the reference of the allocated tensors until the end of the - // allocator. - allocated_tensors_.push_back(temporary_memory); - return se::port::StatusOr( - DeviceMemoryBytes::MakeFromByteSize( - temporary_memory.flat().data(), - temporary_memory.flat().size())); + string ToString() const { + // clang-format off + return strings::StrCat( + trans_a_, ", ", trans_b_, ", ", adj_a_, ", ", adj_b_, ", ", + m_, ", ", n_, ", ", k_, ", ", batch_count_, ", ", + broadcast_a_, ", ", broadcast_b_, ", ", + dtype_ab_, ", ", dtype_cd_, ", ", allow_tf32_, ", ", device_id_); + // clang-format on } private: - OpKernelContext* context_; - std::vector allocated_tensors_; + typedef std::tuple + ParameterDataType; + + ParameterDataType get_data_as_tuple() const { + return std::make_tuple(trans_a_, trans_b_, adj_a_, adj_b_, m_, n_, k_, + batch_count_, broadcast_a_, broadcast_b_, dtype_ab_, + dtype_cd_, allow_tf32_, device_id_); + } + + bool trans_a_; + bool trans_b_; + bool adj_a_; + bool adj_b_; + uint64 m_; + uint64 n_; + uint64 k_; + uint64 batch_count_; + bool broadcast_a_; + bool broadcast_b_; + DataType dtype_ab_; + DataType dtype_cd_; + bool allow_tf32_; + int device_id_; + uint64 hash_code_; }; + +bool GetBlasComputationType(const DataType& dtype, bool allow_tf32, + se::blas::ComputationType* compute_type) { + using se::blas::ComputationType; + static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input(); + ComputationType f32_type = + allow_tf32 ? ComputationType::kTF32AsF32 : ComputationType::kF32; + switch (dtype) { + case DT_HALF: + case DT_BFLOAT16: + *compute_type = + use_f32_for_f16_computation ? f32_type : ComputationType::kF16; + return true; + case DT_FLOAT: + *compute_type = f32_type; + return true; + case DT_DOUBLE: + *compute_type = ComputationType::kF64; + return true; + case DT_COMPLEX64: + *compute_type = f32_type; + return true; + case DT_COMPLEX128: + *compute_type = ComputationType::kComplexF64; + return true; + default: + // Unsupported compute_type, return false. + return false; + } +} + +// Thread-safe map from matmul parameters to their corresponding plan and +// algorithms. +template +class BlasLtMatmulPlanMap { + public: + struct PlanAndAlgorithms { + std::unique_ptr plan; + std::vector> algorithms; + }; + + const PlanAndAlgorithms* Find(const Parameters& params) { + mutex_lock lock(mu_); + auto iter = params_plan_map_.find(params); + if (iter == params_plan_map_.end()) { + return nullptr; + } + return &iter->second; + } + const PlanAndAlgorithms* Insert(const Parameters& params, + PlanAndAlgorithms value) { + mutex_lock lock(mu_); + return ¶ms_plan_map_.emplace(params, std::move(value)).first->second; + } + + private: + struct Hasher { + std::size_t operator()(const Parameters& parameter) const { + return parameter.hash(); + } + }; + + mutable mutex mu_; + std::unordered_map params_plan_map_ + GUARDED_BY(mu_); +}; + +template +struct BlasLtPlanMapSingleton { + typedef BlasLtMatmulPlanMap PlanMapType; + static PlanMapType* GetInstance() { + static PlanMapType* instance = new PlanMapType(); + return instance; + } +}; + +typedef BlasLtPlanMapSingleton + BatchMatmulPlanMapSingleton; + +// A dummy type to group matmul autotune results together. +struct BatchMatmulAutoTuneGroup { + static string name() { return "MatmulLt"; } +}; + +typedef AutoTuneSingleton + AutoTuneBatchMatmul; + +template +struct CoefficientType { + typedef Scalar type; +}; +template <> +struct CoefficientType { + typedef float type; +}; + +inline Status FromExecutorStatus(const se::port::Status& s) { + return s.ok() ? Status::OK() + : Status(static_cast(static_cast(s.code())), + s.error_message()); +} + +template +inline Status FromExecutorStatus(const se::port::StatusOr& s) { + return FromExecutorStatus(s.status()); +} + } // namespace template struct LaunchBatchMatMul { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, Tensor* out) { + bool trans_y, const MatMulBCast& bcast, bool use_autotune, + Tensor* out) { se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, se::blas::Transpose::kConjugateTranspose}; @@ -343,10 +517,191 @@ struct LaunchBatchMatMul { auto* a_base_ptr = in_x.template flat().data(); auto* b_base_ptr = in_y.template flat().data(); auto* c_base_ptr = out->template flat().data(); - uint64 a_stride; - uint64 b_stride; - uint64 c_stride; + int64 a_stride; + int64 b_stride; + int64 c_stride; + typedef typename CoefficientType::type Coefficient; + + static const int64 max_scratch_size = GetBlasWorkspaceLimit( + "TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default + + // The BlasLtMatmul routines are only supported from CUDA 11.0 onward. +#if GOOGLE_CUDA && CUDA_VERSION >= 11000 + bool is_full_broadcast = + std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; + bool requires_mixed_broadcasting = + bcast.IsBroadcastingRequired() && !is_full_broadcast; + if (!requires_mixed_broadcasting) { + bool broadcast_a = bcast.x_batch_size() == 1; + bool broadcast_b = bcast.y_batch_size() == 1; + a_stride = broadcast_a ? 0 : m * k; + b_stride = broadcast_b ? 0 : k * n; + c_stride = m * n; + a_device_memory.push_back(AsDeviceMemory(a_base_ptr)); + b_device_memory.push_back(AsDeviceMemory(b_base_ptr)); + c_device_memory.push_back(AsDeviceMemory(c_base_ptr)); + a_ptrs.push_back(&a_device_memory.back()); + b_ptrs.push_back(&b_device_memory.back()); + c_ptrs.push_back(&c_device_memory.back()); + + DataType dtype = DataTypeToEnum::value; + bool allow_tf32 = tensor_float_32_execution_enabled(); + int device_id = stream->parent()->device_ordinal(); + BatchMatmulParameters matmul_parameters( + trans_x, trans_y, adj_x, adj_y, m, n, k, batch_size, broadcast_a, + broadcast_b, dtype, dtype, allow_tf32, device_id); + + static const bool max_autotune_algorithm_count = + MatmulMaxAutotuneAlgorithmCount(); + int max_algorithm_count = use_autotune ? max_autotune_algorithm_count : 1; + + const auto* plan_and_algorithms = + BatchMatmulPlanMapSingleton::GetInstance()->Find(matmul_parameters); + if (!plan_and_algorithms) { + se::blas::DataType blas_dtype = se::blas::ToDataType::value; + se::blas::ComputationType computation_type; + OP_REQUIRES( + context, + GetBlasComputationType(dtype, allow_tf32, &computation_type), + errors::Internal("Unsupported dtype for batched matmul")); + + auto status_or_plan = stream->parent()->CreateBlasLtMatmulPlan( + {/*ab_type=*/blas_dtype, + /*c_type=*/blas_dtype, computation_type, + se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault, + blas_transpose_b, blas_transpose_a, n, m, k, + /*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), + /*ldc=*/static_cast(n), static_cast(batch_size), + b_stride, a_stride, c_stride}); + OP_REQUIRES(context, status_or_plan.ok(), + FromExecutorStatus(status_or_plan)); + std::unique_ptr plan = + status_or_plan.ConsumeValueOrDie(); + + auto status_or_algorithms = stream->parent()->GetBlasLtMatmulAlgorithms( + plan.get(), max_scratch_size, max_algorithm_count); + OP_REQUIRES(context, status_or_algorithms.ok(), + FromExecutorStatus(status_or_algorithms)); + auto algorithms = status_or_algorithms.ConsumeValueOrDie(); + + plan_and_algorithms = + BatchMatmulPlanMapSingleton::GetInstance()->Insert( + matmul_parameters, {std::move(plan), std::move(algorithms)}); + } + const auto& plan = plan_and_algorithms->plan; + const auto& algorithms = plan_and_algorithms->algorithms; + + // The BlasLtMatmul routines (unlike BlasGemm, BlasGemmBatched etc.) take + // alpha and beta with the same type as the matrices. + Scalar alpha(1.0); + Scalar beta(0.0); + + // Note that algorithm_config.algorithm() here is used to refer + // to the index within the algorithms vector, not the algorithm + // itself. + se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm); + if (max_algorithm_count == 1) { + algorithm_config.set_algorithm(0); + } else if (!AutoTuneBatchMatmul::GetInstance()->Find(matmul_parameters, + &algorithm_config)) { + VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size() + << " algorithms."; + se::blas::ProfileResult best_result; + se::blas::ProfileResult profile_result; + // for (const auto& profile_algorithm : plan_and_algorithms->algorithms) + // { + for (size_t i = 0; i != algorithms.size(); ++i) { + const auto& profile_algorithm = algorithms[i]; + // Create a new scratch allocator with every autotuning run so that + // scratch space is deallocated between runs. + BlasScratchAllocator scratch_allocator(max_scratch_size, context); + + bool cublas_launch_status = + stream + ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0], + beta, c_ptrs[0], &scratch_allocator, + profile_algorithm.get(), {}, + &profile_result) + .ok(); + + VLOG(4) << " Autotune algorithm " << i + << " result: " << profile_result.elapsed_time_in_ms() + << " ms, valid=" << profile_result.is_valid() + << ", workspace_size=" << profile_algorithm->workspace_size(); + + if (cublas_launch_status && profile_result.is_valid() && + profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + } + + if (best_result.is_valid()) { + algorithm_config.set_algorithm(best_result.algorithm()); + } + // We make sure that each matmul parameter set only gets one pass of + // autotune. If no algorithms works, we add kNoAlgorithm to the autotune + // map. + AutoTuneBatchMatmul::GetInstance()->Insert(matmul_parameters, + algorithm_config); + } + se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm(); + OP_REQUIRES(context, + 0 <= algorithm_idx && algorithm_idx < algorithms.size(), + errors::Internal("Missing/invalid BatchMatmul algorithm")); + const auto& algorithm = algorithms[algorithm_idx]; + BlasScratchAllocator scratch_allocator(max_scratch_size, context); + bool cublas_launch_status = + stream + ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0], + beta, c_ptrs[0], &scratch_allocator, + algorithm.get()) + .ok(); + if (!cublas_launch_status) { + context->SetStatus(errors::Internal( + "Blas batched matmul launch failed : a.shape=(", + bcast.x_batch_size(), ", ", in_x.dim_size(0), ", ", + in_x.dim_size(1), "), b.shape=(", bcast.y_batch_size(), ", ", + in_y.dim_size(0), ", ", in_y.dim_size(1), "), m=", m, ", n=", n, + ", k=", k, ", batch_size=", batch_size)); + } + } else { // requires mixed broadcasting + const std::vector& a_batch_indices = bcast.x_batch_indices(); + const std::vector& b_batch_indices = bcast.y_batch_indices(); + for (int64 i = 0; i < bcast.x_batch_size(); ++i) { + a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); + } + for (int64 i = 0; i < bcast.y_batch_size(); ++i) { + b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); + } + for (int64 i = 0; i < batch_size; ++i) { + c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); + a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]); + b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]); + c_ptrs.push_back(&c_device_memory.back()); + } + + BlasScratchAllocator scratch_allocator(max_scratch_size, context); + bool blas_launch_status = + stream + ->ThenBlasGemmBatchedWithScratch( + blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), b_ptrs, + adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, + static_cast(0.0), c_ptrs, n, batch_size, + &scratch_allocator) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal( + "Blas xGEMMBatched launch failed : a.shape=", + in_x.shape().DebugString(), + ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, + ", k=", k, ", batch_size=", batch_size)); + } + } + return; +#else // if not GOOGLE_CUDA or CUDA_VERSION < 11000 bool is_full_broadcast = std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; bool use_strided_batched = @@ -388,8 +743,6 @@ struct LaunchBatchMatMul { } } - typedef Scalar Coefficient; - // Blas does // C = A x B // where A, B and C are assumed to be in column major. @@ -399,7 +752,10 @@ struct LaunchBatchMatMul { if (batch_size == 1) { // This is a regular matrix*matrix or matrix*vector multiply. Avoid the // overhead of the scratch allocator and the batch interface. - if (n == 1 && + // Note that the GEMV call here does not support Eigen::half, so we do not + // use this path in that case. A workaround is applied to the pointers + // passed to the call itself to avoid compilation errors. + if (!std::is_same::value && n == 1 && blas_transpose_b != se::blas::Transpose::kConjugateTranspose && blas_transpose_a != se::blas::Transpose::kConjugateTranspose) { // This is a matrix*vector multiply so use GEMV to compute A * b. @@ -410,13 +766,19 @@ struct LaunchBatchMatMul { auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose ? se::blas::Transpose::kNoTranspose : se::blas::Transpose::kTranspose; + // Cast pointers as a workaround for GEMV not supporting Eigen::half + // (this will never actually be executed for Eigen::half). + typedef se::DeviceMemory NonHalfDeviceMemoryType; + NonHalfDeviceMemoryType a_ptr(*(a_ptrs[0])); + NonHalfDeviceMemoryType b_ptr(*(b_ptrs[0])); + NonHalfDeviceMemoryType c_ptr(*(c_ptrs[0])); bool blas_launch_status = stream ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k, adj_x || trans_x ? k : m, - static_cast(1.0), *(a_ptrs[0]), - adj_x || trans_x ? m : k, *(b_ptrs[0]), 1, - static_cast(0.0), c_ptrs[0], 1) + static_cast(1.0), a_ptr, + adj_x || trans_x ? m : k, b_ptr, 1, + static_cast(0.0), &c_ptr, 1) .ok(); if (!blas_launch_status) { context->SetStatus(errors::Internal( @@ -459,154 +821,7 @@ struct LaunchBatchMatMul { ", k=", k, ", batch_size=", batch_size)); } } else { - BlasScratchAllocator scratch_allocator(context); - bool blas_launch_status = - stream - ->ThenBlasGemmBatchedWithScratch( - blas_transpose_b, blas_transpose_a, n, m, k, - static_cast(1.0), b_ptrs, - adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, - static_cast(0.0), c_ptrs, n, batch_size, - &scratch_allocator) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal( - "Blas xGEMMBatched launch failed : a.shape=", - in_x.shape().DebugString(), - ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, - ", k=", k, ", batch_size=", batch_size)); - } - } - } -}; - -template <> -struct LaunchBatchMatMul { - static void Launch(OpKernelContext* context, const Tensor& in_x, - const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, Tensor* out) { - typedef Eigen::half Scalar; - se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose, - se::blas::Transpose::kTranspose, - se::blas::Transpose::kConjugateTranspose}; - const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1); - const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2); - const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2); - const uint64 batch_size = bcast.output_batch_size(); - auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)]; - auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)]; - - auto* stream = context->op_device_context()->stream(); - OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - - typedef perftools::gputools::DeviceMemory DeviceMemoryType; - std::vector a_device_memory; - std::vector b_device_memory; - std::vector c_device_memory; - std::vector a_ptrs; - std::vector b_ptrs; - std::vector c_ptrs; - a_device_memory.reserve(bcast.x_batch_size()); - b_device_memory.reserve(bcast.y_batch_size()); - c_device_memory.reserve(batch_size); - a_ptrs.reserve(batch_size); - b_ptrs.reserve(batch_size); - c_ptrs.reserve(batch_size); - auto* a_base_ptr = in_x.template flat().data(); - auto* b_base_ptr = in_y.template flat().data(); - auto* c_base_ptr = out->template flat().data(); - - uint64 a_stride; - uint64 b_stride; - uint64 c_stride; - - bool is_full_broadcast = - std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; - bool use_strided_batched = - (!bcast.IsBroadcastingRequired() || is_full_broadcast) && - batch_size > 1; - if (use_strided_batched) { - a_stride = bcast.x_batch_size() != 1 ? m * k : 0; - b_stride = bcast.y_batch_size() != 1 ? k * n : 0; - c_stride = m * n; - a_device_memory.push_back(AsDeviceMemory(a_base_ptr)); - b_device_memory.push_back(AsDeviceMemory(b_base_ptr)); - c_device_memory.push_back(AsDeviceMemory(c_base_ptr)); - a_ptrs.push_back(&a_device_memory.back()); - b_ptrs.push_back(&b_device_memory.back()); - c_ptrs.push_back(&c_device_memory.back()); - } else if (!bcast.IsBroadcastingRequired()) { - for (int64 i = 0; i < batch_size; ++i) { - a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); - b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); - c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); - a_ptrs.push_back(&a_device_memory.back()); - b_ptrs.push_back(&b_device_memory.back()); - c_ptrs.push_back(&c_device_memory.back()); - } - } else { - const std::vector& a_batch_indices = bcast.x_batch_indices(); - const std::vector& b_batch_indices = bcast.y_batch_indices(); - for (int64 i = 0; i < bcast.x_batch_size(); ++i) { - a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); - } - for (int64 i = 0; i < bcast.y_batch_size(); ++i) { - b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); - } - for (int64 i = 0; i < batch_size; ++i) { - c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); - a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]); - b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]); - c_ptrs.push_back(&c_device_memory.back()); - } - } - - typedef float Coefficient; - - // Blas does - // C = A x B - // where A, B and C are assumed to be in column major. - // We want the output to be in row-major, so we can compute - // C' = B' x A', where ' stands for transpose (not adjoint). - // TODO(yangzihao): Choose the best of the three strategies using autotune. - if (batch_size == 1) { - // This is a regular matrix*matrix or matrix*vector multiply. Avoid the - // overhead of the scratch allocator and the batch interface. - // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS - bool blas_launch_status = - stream - ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, - static_cast(1.0), *(b_ptrs[0]), - adj_y || trans_y ? k : n, *(a_ptrs[0]), - adj_x || trans_x ? m : k, - static_cast(0.0), c_ptrs[0], n) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal( - "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(), - ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, - ", k=", k)); - } - } else if (use_strided_batched) { - bool blas_launch_status = - stream - ->ThenBlasGemmStridedBatched( - blas_transpose_b, blas_transpose_a, n, m, k, - static_cast(1.0), *b_ptrs[0], - adj_y || trans_y ? k : n, b_stride, *a_ptrs[0], - adj_x || trans_x ? m : k, a_stride, - static_cast(0.0), c_ptrs[0], n, c_stride, - batch_size) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal( - "Blas xGEMMStridedBatched launch failed : a.shape=", - in_x.shape().DebugString(), - ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, - ", k=", k, ", batch_size=", batch_size)); - } - } else { - BlasScratchAllocator scratch_allocator(context); + BlasScratchAllocator scratch_allocator(max_scratch_size, context); bool blas_launch_status = stream ->ThenBlasGemmBatchedWithScratch( @@ -624,6 +839,7 @@ struct LaunchBatchMatMul { ", k=", k, ", batch_size=", batch_size)); } } +#endif // not GOOGLE_CUDA or CUDA_VERSION < 11000 } }; @@ -637,6 +853,7 @@ class BaseBatchMatMulOp : public OpKernel { : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_)); OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_)); + use_autotune_ = MatmulAutotuneEnable(); } ~BaseBatchMatMulOp() override {} @@ -698,7 +915,7 @@ class BaseBatchMatMulOp : public OpKernel { out->shape().DebugString())); LaunchBatchMatMul::Launch( ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false, - /*trans_y=*/false, bcast, &out_reshaped); + /*trans_y=*/false, bcast, use_autotune_, &out_reshaped); } protected: @@ -708,6 +925,7 @@ class BaseBatchMatMulOp : public OpKernel { private: bool adj_x_; bool adj_y_; + bool use_autotune_; }; // BatchMatMul Op implementation which disallows broadcasting. diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index 9bb853f708e..ed6e9a47cad 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -643,8 +643,9 @@ Status Queue::ScheduleWithoutSplit(std::unique_ptr* task) { } profiler::TraceMeProducer trace_me( [task] { - return profiler::TraceMeEncode("Schedule", - {{"size", (*task)->size()}}); + return profiler::TraceMeEncode( + "ScheduleWithoutSplit", + {{"batching_input_task_size", (*task)->size()}}); }, profiler::ContextType::kSharedBatchScheduler, batches_.back()->traceme_context_id()); @@ -672,8 +673,8 @@ Status Queue::ScheduleWithoutSplit(std::unique_ptr* task) { template Status Queue::ScheduleWithSplit(std::unique_ptr* task) { profiler::TraceMe trace_me([task] { - return profiler::TraceMeEncode("ScheduleWithSplit", - {{"size", (*task)->size()}}); + return profiler::TraceMeEncode( + "ScheduleWithSplit", {{"batching_input_task_size", (*task)->size()}}); }); if ((*task)->size() > options_.input_batch_size_limit) { return errors::InvalidArgument("Task size ", (*task)->size(), diff --git a/tensorflow/core/kernels/collective_nccl_reducer.cc b/tensorflow/core/kernels/collective_nccl_reducer.cc index 451f2cb96bc..777c5fc8fc7 100644 --- a/tensorflow/core/kernels/collective_nccl_reducer.cc +++ b/tensorflow/core/kernels/collective_nccl_reducer.cc @@ -113,7 +113,7 @@ void NcclReducer::Run(StatusCallback done) { if (final_status.ok()) { final_status = collective_util::ComputeBinOp( col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device, - col_params_->final_op.get(), col_ctx_->output, &group_size); + col_params_->final_op, col_ctx_->output, &group_size); } done(final_status); } diff --git a/tensorflow/core/kernels/collective_nccl_test.cc b/tensorflow/core/kernels/collective_nccl_test.cc index f7725151d8a..04399504978 100644 --- a/tensorflow/core/kernels/collective_nccl_test.cc +++ b/tensorflow/core/kernels/collective_nccl_test.cc @@ -248,6 +248,8 @@ class NcclTestBase : public ::testing::Test { TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(device_name_, &device_)) << "Could not find device " << device_name_ << " existing devices " << parent_->dev_mgr_->DebugString(); + merge_op_ = GetAdd(device_); + final_op_ = GetDiv(device_); col_params_.name = parent_->col_params_.name; col_params_.default_rank = rank; col_params_.group = parent_->col_params_.group; @@ -414,6 +416,8 @@ class NcclTestBase : public ::testing::Test { Tensor output_; Device* device_; CollectiveParams col_params_; + std::unique_ptr merge_op_; + std::unique_ptr final_op_; Status status_; }; @@ -459,8 +463,8 @@ class NcclReducerTest : public NcclTestBase { } void InitDevice(DeviceInstance* di) override { - di->col_params_.merge_op = GetAdd(di->device_); - di->col_params_.final_op = GetDiv(di->device_); + di->col_params_.merge_op = di->merge_op_.get(); + di->col_params_.final_op = di->final_op_.get(); } void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunReduce(); } diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc index a3db45dfea6..357ae158ea1 100644 --- a/tensorflow/core/kernels/collective_ops.cc +++ b/tensorflow/core/kernels/collective_ops.cc @@ -49,9 +49,40 @@ static std::unique_ptr BuildOpKernel(OpKernelConstruction* c, return k; } -class CollectiveOpKernel : public AsyncOpKernel { +class CollectiveOpV1Kernel : public AsyncOpKernel { public: - explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {} + explicit CollectiveOpV1Kernel(OpKernelConstruction* c) + : AsyncOpKernel(c), name_(name()) {} + + void ComputeAsync(OpKernelContext* c, DoneCallback done) override { + CollectiveExecutor* col_exec = c->collective_executor(); + OP_REQUIRES_ASYNC( + c, col_exec, + errors::Internal( + "Failed to get CollectiveExecutor from OpKernelContext for Op ", + name_), + done); + const CancellationToken token = + c->cancellation_manager()->get_cancellation_token(); + const bool already_cancelled = + !c->cancellation_manager()->RegisterCallback(token, [col_exec]() { + // We must call StartAbort() within the callback. StartAbort() relies + // on resources that may be deallocated if all execution of a graph is + // finished. + col_exec->StartAbort(errors::Cancelled("op cancelled")); + }); + OP_REQUIRES_ASYNC(c, !already_cancelled, + errors::Cancelled("op cancelled ", name_), done); + + auto deregister_and_done = [c, col_exec, token, done = std::move(done)]() { + // Once done() is called, StartAbort() won't have any effect, so we + // don't need to block on the deregistration. Also StartAbort() may call + // done() and DeregisterCallback may deadlock. + c->cancellation_manager()->TryDeregisterCallback(token); + done(); + }; + ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done)); + } // A string encoding instance, frame and iter to be handed off to // the implementation for use in generating RecvBuf keys. @@ -90,14 +121,20 @@ class CollectiveOpKernel : public AsyncOpKernel { return true; } + protected: + virtual void ComputeAsyncImpl(OpKernelContext* c, + CollectiveExecutor* col_exec, + DoneCallback done) = 0; + + string name_; CollectiveParams col_params_; std::vector dependencies_; }; -class CollectiveGatherOpKernel : public CollectiveOpKernel { +class CollectiveGatherOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveGatherOpKernel(OpKernelConstruction* c) - : CollectiveOpKernel(c) { + : CollectiveOpV1Kernel(c) { col_params_.instance.type = GATHER_COLLECTIVE; OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); OP_REQUIRES( @@ -119,15 +156,9 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel { col_params_.group.device_type = c->device_type(); } - void ComputeAsync(OpKernelContext* c, DoneCallback done) override { - CollectiveExecutor* col_exec = c->collective_executor(); - OP_REQUIRES_ASYNC( - c, col_exec, - errors::Internal( - "Failed to get CollectiveExecutor from OpKernelContext for Op ", - col_params_.name), - done); - + protected: + void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec, + DoneCallback done) override { auto output_shape = c->input(0).shape(); output_shape.set_dim( 0, output_shape.dim_size(0) * col_params_.group.group_size); @@ -171,10 +202,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU), CollectiveGatherOpKernel); -class CollectiveReduceOpKernel : public CollectiveOpKernel { +class CollectiveReduceOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveReduceOpKernel(OpKernelConstruction* c) - : CollectiveOpKernel(c) { + : CollectiveOpV1Kernel(c) { col_params_.instance.type = REDUCTION_COLLECTIVE; OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); OP_REQUIRES( @@ -227,18 +258,15 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel { sub_node.set_device(real_node.device()); SetAttrValue(col_params_.instance.data_type, &(*sub_node.mutable_attr())["T"]); - col_params_.merge_op = BuildOpKernel(c, merge_op_name, &sub_node); - col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node); + merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node); + final_op_ = BuildOpKernel(c, final_op_name, &sub_node); + col_params_.merge_op = merge_op_.get(); + col_params_.final_op = final_op_.get(); } - void ComputeAsync(OpKernelContext* c, DoneCallback done) override { - CollectiveExecutor* col_exec = c->collective_executor(); - OP_REQUIRES_ASYNC( - c, col_exec, - errors::Internal( - "Failed to get CollectiveExecutor from OpKernelContext for Op ", - col_params_.name), - done); + protected: + void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec, + DoneCallback done) override { // Allocate output on the first pass through this function. This must be // done immediately, while we're still in the executor thread. Otherwise // the memory is not guaranteed to be unused by any concurrently executing @@ -272,6 +300,8 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel { } private: + std::unique_ptr merge_op_; + std::unique_ptr final_op_; TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel); }; @@ -280,10 +310,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU), CollectiveReduceOpKernel); -class CollectiveBcastSendOpKernel : public CollectiveOpKernel { +class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c) - : CollectiveOpKernel(c) { + : CollectiveOpV1Kernel(c) { col_params_.instance.type = BROADCAST_COLLECTIVE; OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); OP_REQUIRES( @@ -309,14 +339,9 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel { col_params_.group.device_type = c->device_type(); } - void ComputeAsync(OpKernelContext* c, DoneCallback done) override { - CollectiveExecutor* col_exec = c->collective_executor(); - OP_REQUIRES_ASYNC( - c, col_exec, - errors::Internal( - "Failed to get CollectiveExecutor from OpKernelContext for Op ", - col_params_.name), - done); + protected: + void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec, + DoneCallback done) override { // Allocate output on the first pass through this function. This must be // done immediately, while we're still in the executor thread. Otherwise // the memory is not guaranteed to be unused by any concurrently executing @@ -362,10 +387,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU), CollectiveBcastSendOpKernel); -class CollectiveBcastRecvOpKernel : public CollectiveOpKernel { +class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c) - : CollectiveOpKernel(c) { + : CollectiveOpV1Kernel(c) { col_params_.instance.type = BROADCAST_COLLECTIVE; OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); OP_REQUIRES( @@ -391,14 +416,9 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel { col_params_.group.device_type = c->device_type(); } - void ComputeAsync(OpKernelContext* c, DoneCallback done) override { - CollectiveExecutor* col_exec = c->collective_executor(); - OP_REQUIRES_ASYNC( - c, col_exec, - errors::Internal( - "Failed to get CollectiveExecutor from OpKernelContext for Op ", - col_params_.name), - done); + protected: + void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec, + DoneCallback done) override { // Allocate output on the first pass through this function. This must be // done immediately, while we're still in the executor thread. Otherwise // the memory is not guaranteed to be unused by any concurrently executing @@ -440,9 +460,8 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU), class CollectiveReduceV2OpKernel : public AsyncOpKernel { public: explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c) - : AsyncOpKernel(c) { - col_params_ = std::make_shared(); - OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type)); + : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) { + OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_)); string merge_op_name; OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name)); OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name)); @@ -453,32 +472,23 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { } string final_op_name; OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name)); - OP_REQUIRES_OK( - c, c->GetAttr("communication_hint", - &col_params_->instance.impl_details.communication_hint)); - OP_REQUIRES_OK( - c, c->GetAttr("timeout_seconds", - &col_params_->instance.impl_details.timeout_seconds)); + OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_)); + OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_)); // Prepare OpKernels for reduction and final operations. // The merge_op takes two inputs NodeDef sub_node; sub_node.add_input(c->def().input(0)); sub_node.add_input(c->def().input(0)); sub_node.set_device(c->def().device()); - SetAttrValue(col_params_->instance.data_type, - &(*sub_node.mutable_attr())["T"]); - col_params_->merge_op = BuildOpKernel(c, merge_op_name, &sub_node); - col_params_->final_op = BuildOpKernel(c, final_op_name, &sub_node); + SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]); + merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node); + final_op_ = BuildOpKernel(c, final_op_name, &sub_node); - col_params_->name = strings::StrCat(c->def().name(), ": ReduceV2(", - merge_op_name, ",", final_op_name, ")"); - col_params_->group.device_type = c->device_type(); - // Add a default value for subdiv offsets, which is the same as the default - // value in the V1 op's attribute. - col_params_->instance.impl_details.subdiv_offsets.push_back(0); - VLOG(2) << "CollectiveReduceV2 " << this << " name " << col_params_->name - << " communication_hint " - << col_params_->instance.impl_details.communication_hint; + name_ = strings::StrCat(c->def().name(), ": ReduceV2(", merge_op_name, ",", + final_op_name, ")"); + device_type_ = c->device_type(); + VLOG(2) << "CollectiveReduceV2 " << this << " name " << name_ + << " communication_hint " << communication_hint_; } void ComputeAsync(OpKernelContext* c, DoneCallback done) override { @@ -487,7 +497,7 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { c, col_exec, errors::Internal( "Failed to get CollectiveExecutor from OpKernelContext for Op ", - col_params_->name), + name_), done); const Tensor& input = c->input(0); const Tensor& group_size = c->input(1); @@ -503,48 +513,49 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { c, instance_key.dims() == 0, errors::Internal("Unexpected dimensions on input instance_key"), done); - auto col_params = std::make_shared(); - col_params->name = col_params_->name; - col_params->group.device_type = col_params_->group.device_type; + auto col_params = new CollectiveParams(); + col_params->name = name_; + col_params->group.device_type = device_type_; col_params->group.group_size = group_size.unaligned_flat()(0); col_params->group.group_key = group_key.unaligned_flat()(0); col_params->instance.type = REDUCTION_COLLECTIVE; col_params->instance.instance_key = instance_key.unaligned_flat()(0); - col_params->instance.data_type = col_params_->instance.data_type; - col_params->instance.impl_details.communication_hint = - col_params_->instance.impl_details.communication_hint; - col_params->instance.impl_details.timeout_seconds = - col_params_->instance.impl_details.timeout_seconds; - col_params->instance.impl_details.subdiv_offsets = - col_params_->instance.impl_details.subdiv_offsets; - col_params->merge_op = std::move(col_params_->merge_op); - col_params->final_op = std::move(col_params_->final_op); + col_params->instance.data_type = data_type_; + col_params->instance.impl_details.communication_hint = communication_hint_; + col_params->instance.impl_details.timeout_seconds = timeout_seconds_; + // Add a default value for subdiv offsets, which is the same as the default + // value in the V1 op's attribute. + col_params->instance.impl_details.subdiv_offsets.push_back(0); + col_params->merge_op = merge_op_.get(); + col_params->final_op = final_op_.get(); VLOG(1) << "CollectiveReduceV2 group_size " << col_params->group.group_size << " group_key " << col_params->group.group_key << " instance_key " << col_params->instance.instance_key; + auto done_with_cleanup = [col_params, done = std::move(done)]() { + delete col_params; + done(); + }; + // Allocate the output tensor, trying to reuse the input. Tensor* output = nullptr; OP_REQUIRES_OK_ASYNC( c, c->forward_input_or_allocate_output({0}, 0, input.shape(), &output), - done); + done_with_cleanup); col_params->instance.shape = input.shape(); - // Store the updated params in this OpKernel. - col_params_ = col_params; - // Resolve the collective params. // Schedule the `CompleteParamsAsync` call on a work queue that can handle // blocking work because it's not guaranteed that this call cannot block. - c->collective_executor()->RunClosure([c, done = std::move(done), col_params, - col_exec]() { + c->collective_executor()->RunClosure([c, + done = std::move(done_with_cleanup), + col_params, col_exec]() { VLOG(1) << "CollectiveReduceV2 CompleteParams for collective " << col_params->name << " device " << c->device()->name() << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->CompleteParamsAsync( - c->device()->attributes(), col_params.get(), - c->cancellation_manager(), + c->device()->attributes(), col_params, c->cancellation_manager(), [c, done = std::move(done), col_params, col_exec](const Status& s) { if (s.ok()) { auto actual_done = [c, group_key = col_params->group.group_key, @@ -578,12 +589,22 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { } private: - std::shared_ptr col_params_; + string name_; + DataType data_type_ = DT_INVALID; + string communication_hint_; + float timeout_seconds_ = 0; + DeviceType device_type_; + std::unique_ptr merge_op_; + std::unique_ptr final_op_; }; REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2").Device(DEVICE_CPU), CollectiveReduceV2OpKernel); -REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2") + .Device(DEVICE_GPU) + .HostMemory("group_size") + .HostMemory("group_key") + .HostMemory("instance_key"), CollectiveReduceV2OpKernel); class CollectiveGatherV2OpKernel : public AsyncOpKernel { @@ -650,15 +671,16 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel { 0, output_shape.dim_size(0) * col_params->group.group_size); col_params->instance.shape = output_shape; - Tensor* output = nullptr; - OP_REQUIRES_OK_ASYNC( - c, c->allocate_output(0, col_params->instance.shape, &output), done); - auto done_with_cleanup = [col_params, done = std::move(done)]() { delete col_params; done(); }; + Tensor* output = nullptr; + OP_REQUIRES_OK_ASYNC( + c, c->allocate_output(0, col_params->instance.shape, &output), + done_with_cleanup); + // Resolve the collective params. // Schedule the `CompleteParamsAsync` call on a work queue that can handle // blocking work because it's not guaranteed that this call cannot block. @@ -704,16 +726,20 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel { } private: - DataType data_type_; - string communication_hint_; - float timeout_seconds_; - DeviceType device_type_; string name_; + DataType data_type_ = DT_INVALID; + string communication_hint_; + float timeout_seconds_ = 0; + DeviceType device_type_; }; REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU), CollectiveGatherV2OpKernel); -REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2") + .Device(DEVICE_GPU) + .HostMemory("group_size") + .HostMemory("group_key") + .HostMemory("instance_key"), CollectiveGatherV2OpKernel); } // namespace diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 025a8e37a94..f5b9e79fb54 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -619,19 +619,7 @@ template struct LaunchConv2DOp; int64 GetDnnWorkspaceLimit(const string& envvar_in_mb, int64 default_value_in_bytes) { - const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str()); - if (workspace_limit_in_mb_str != nullptr && - strcmp(workspace_limit_in_mb_str, "") != 0) { - int64 scratch_limit_in_mb = -1; - if (strings::safe_strto64(workspace_limit_in_mb_str, - &scratch_limit_in_mb)) { - return scratch_limit_in_mb * (1 << 20); - } else { - LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": " - << workspace_limit_in_mb_str; - } - } - return default_value_in_bytes; + return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes); } // A dummy type to group forward convolution autotune results together. diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 2e97d486b54..8beab722a64 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -48,52 +48,7 @@ int64 GetDnnWorkspaceLimit(const string& envvar_in_mb, // A class to provide scratch-space allocator for Stream-Executor Cudnn // callback. TensorFlow is responsible for releasing the temporary buffers after // the kernel finishes. -class DnnScratchAllocator : public se::ScratchAllocator { - public: - virtual ~DnnScratchAllocator() {} - DnnScratchAllocator(int64 memory_limit, OpKernelContext* context) - : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} - int64 GetMemoryLimitInBytes() override { return memory_limit_; } - se::port::StatusOr> AllocateBytes( - int64 byte_size) override { - Tensor temporary_memory; - if (byte_size < 0) { - return se::port::Status{se::port::error::INVALID_ARGUMENT, - "Requested negative byte size!"}; - } - if (byte_size > memory_limit_) { - return se::port::Status{se::port::error::UNAVAILABLE, - absl::StrCat("Requested memory size (", byte_size, - ") exceeds the max memory limit (", - memory_limit_, ").")}; - } - AllocationAttributes allocation_attr; - allocation_attr.retry_on_failure = false; - Status allocation_status(context_->allocate_temp( - DT_UINT8, TensorShape({byte_size}), &temporary_memory, - AllocatorAttributes(), allocation_attr)); - if (!allocation_status.ok()) { - return se::port::Status{ - se::port::error::UNAVAILABLE, - absl::StrCat("Failed to allocate the requested memory size (", - byte_size, ").")}; - } - // Hold the reference of the allocated tensors until the end of the - // allocator. - allocated_tensors_.push_back(temporary_memory); - total_byte_size_ += byte_size; - return se::port::StatusOr>( - AsDeviceMemory(temporary_memory.flat().data(), - temporary_memory.flat().size())); - } - int64 TotalByteSize() { return total_byte_size_; } - - private: - int64 memory_limit_; - int64 total_byte_size_; - OpKernelContext* context_; - std::vector allocated_tensors_; -}; +using DnnScratchAllocator = GpuScratchAllocator; // Encapsulate all the shape information that is used in both forward and // backward conv operations. diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 88651d7bfdc..5aa74bb7d3d 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -1071,10 +1071,12 @@ struct safe_pow : base> { }; template -struct maximum : base> {}; +struct maximum + : base> {}; template -struct minimum : base> {}; +struct minimum + : base> {}; template struct igamma : base> {}; @@ -1097,9 +1099,7 @@ struct scalar_atan2_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& y, const Scalar& x) const { -#if GOOGLE_CUDA - return std::atan2(y, x); -#elif TENSORFLOW_USE_ROCM +#if TENSORFLOW_USE_ROCM return ::atan2(y, x); #else return std::atan2(y, x); diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 9f351edf11a..29aaf4c4d3e 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -203,9 +203,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:testlib", + "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/kernels:function_ops", "//third_party/eigen3", "@com_google_absl//absl/strings", diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index c3ef8122ebb..92bc48159ad 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -424,6 +424,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core/kernels:ragged_tensor_variant", "//tensorflow/core/kernels/data:dataset_utils", "//tensorflow/core/kernels/data:name_utils", "//tensorflow/core/kernels/data:parallel_map_dataset_op", diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index d8176eb9499..bdea5724911 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -263,8 +263,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { }); } - while (results_.empty() && !job_finished_ && !cancelled_ && - status_.ok()) { + while (results_.empty() && + !(job_finished_ && num_running_worker_threads_ == 0) && + !cancelled_ && status_.ok()) { get_next_cv_.wait(l); } if (cancelled_) { @@ -370,6 +371,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { job_finished_ = job_finished; if (job_finished) { get_next_cv_.notify_all(); + worker_thread_cv_.notify_all(); return; } for (int i = 0; i < tasks_.size(); ++i) { @@ -416,6 +418,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); num_running_worker_threads_--; outstanding_requests_--; + get_next_cv_.notify_all(); }; worker_threads_.push_back(ctx->StartThread( "tf-data-service-task_thread", [this, done = std::move(done)]() { @@ -440,7 +443,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { worker_thread_cv_.notify_one(); } outstanding_requests_--; - while (!cancelled_ && !(SpaceInBuffer() && TaskAvailable())) { + while (!cancelled_ && !(SpaceInBuffer() && TaskAvailable()) && + !job_finished_) { if (VLOG_IS_ON(3)) { VLOG(3) << "Sleeping with results_.size=" << results_.size() << ", outstanding_requests_=" << outstanding_requests_ @@ -452,7 +456,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { worker_thread_cv_.wait(l); } outstanding_requests_++; - if (cancelled_) { + if (cancelled_ || job_finished_) { return; } // Search for a task to update. @@ -475,8 +479,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { Status s = GetElement(task_to_process.get(), deadline_micros); if (!s.ok()) { mutex_lock l(mu_); - VLOG(1) << "Failed to get element for task " - << task_to_process->task_id << ": " << s; + VLOG(1) << "Failed to get element from worker " + << task_to_process->address << ": " << s; task_to_process->in_use = false; status_ = s; get_next_cv_.notify_all(); @@ -529,6 +533,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { (deadline_micros > deadline_with_backoff_micros) ? deadline_with_backoff_micros : deadline_micros; + VLOG(1) << "Failed to get an element from worker " << task->address + << ": " << s << ". Will retry in " + << (backoff_until - now_micros) << " microseconds"; Env::Default()->SleepForMicroseconds(backoff_until - now_micros); } diff --git a/tensorflow/core/kernels/data/experimental/data_service_ops.cc b/tensorflow/core/kernels/data/experimental/data_service_ops.cc index b9f58b99b0f..4d993d9462f 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_ops.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_ops.cc @@ -57,7 +57,6 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) { GraphDef graph_def; OP_REQUIRES_OK( ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def)); - StripDevicePlacement(graph_def.mutable_library()); DataServiceDispatcherClient client(address, protocol); int64 dataset_id; diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index 5cc72ba853e..717c6a9fa21 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -527,6 +527,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { "Failed to allocate memory for the batch of component ", i); } } + RecordBufferEnqueue(ctx.get(), result->output); result->output_allocated = true; return Status::OK(); } @@ -536,6 +537,9 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { std::vector* out_tensors, bool* end_of_sequence) { mutex_lock l(result->mu); + if (result->output_allocated) { + RecordBufferDequeue(ctx, result->output); + } if (result->num_elements == 0) { if (result->status.ok() || errors::IsOutOfRange(result->status)) { *end_of_sequence = true; diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc index 16cf7fe6416..80f23bb5a0c 100644 --- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/name_utils.h" #include "tensorflow/core/kernels/data/parallel_map_dataset_op.h" #include "tensorflow/core/kernels/data/stats_utils.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -678,12 +679,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { for (int d = 0; d < dataset()->ragged_keys_.size(); ++d) { int output_index = dataset()->key_to_output_index_.at(dataset()->ragged_keys_[d]); - (*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {}); - Tensor serialized_ragged = - Tensor(ctx->allocator({}), DT_VARIANT, {2}); - auto serialized_ragged_t = serialized_ragged.vec(); - serialized_ragged_t(0) = example_result.ragged_splits[d]; - serialized_ragged_t(1) = example_result.ragged_values[d]; + RaggedTensorVariant serialized_ragged; + serialized_ragged.append_splits(example_result.ragged_splits[d]); + serialized_ragged.set_values(example_result.ragged_values[d]); (*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {}); Tensor& ragged_wrapper = (*output)[output_index]; ragged_wrapper.scalar()() = serialized_ragged; diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc index e2cbe7d9dcc..7a65baaa680 100644 --- a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc @@ -417,7 +417,6 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { std::vector slices; slices.reserve(tensors_.size()); for (const auto& tensor : tensors_) { - Tensor slice = tensor.Slice(offset_, slice_end); slices.push_back(tensor.Slice(offset_, slice_end)); } slices_to_concatenate.push_back(std::move(slices)); @@ -452,8 +451,28 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { if (desired_batch_size == 0) { DCHECK_EQ(batch_size, 0); DCHECK_EQ(slices_to_concatenate.size(), 0); - for (const auto& dtype : dataset()->output_dtypes()) { - out_tensors->push_back(Tensor(dtype)); + for (int i = 0; i < dataset()->output_dtypes().size(); ++i) { + if (dataset()->output_shapes()[i].unknown_rank()) { + // For unknown rank tensors, we just create a empty Tensor since + // it doesn't matter what shape it is. + out_tensors->push_back(Tensor(dataset()->output_dtypes()[i])); + } else { + auto dim_sizes = dataset()->output_shapes()[i].dim_sizes(); + + // The output batch size is always zero since the desired batch + // size is zero. + dim_sizes[0] = 0; + + // Handle unknown dimensions by setting any unknown dimensions to + // zero since there isn't any data anyway. + for (int j = 1; j < dim_sizes.size(); ++j) { + if (dim_sizes[j] == -1) dim_sizes[j] = 0; + } + + TensorShape tensor_shape(dim_sizes); + out_tensors->push_back( + Tensor(dataset()->output_dtypes()[i], tensor_shape)); + } } return Status::OK(); } diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 6e079937885..2fcf8e95558 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -84,7 +84,7 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, // of the Borg jobs, the experiments will be randomly turned on. // clang-format off absl::flat_hash_map live_experiments = { - {"enable_gradient_descent", 5} + {"enable_gradient_descent", 20} }; // clang-format on auto hash_func = [](const string& str) { return Hash64(str); }; diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 4a55514ffd1..96e3872b25b 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -369,6 +369,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } *out_tensors = std::move(buffer_.front().value); RecordBufferDequeue(ctx, *out_tensors); + } else { + // If status not ok, we still record the dequeue event to make sure each + // enqueue event is paired with a dequeue event even in the presence of + // errors. + RecordBufferDequeue(ctx, buffer_.front().value); } if (legacy_autotune_) { auto_tuner_.RecordConsumption(buffer_.size()); diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc index 0a940e52eb7..d0be01578d8 100644 --- a/tensorflow/core/kernels/example_parsing_ops.cc +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -222,7 +222,7 @@ class ParseExampleOp : public OpKernel { for (int d = 0; d < attrs_.num_sparse; ++d) { config.sparse.emplace_back(sparse_keys_t[d], attrs_.sparse_types[d]); } - config.sparse.reserve(attrs_.num_ragged); + config.ragged.reserve(attrs_.num_ragged); for (int d = 0; d < attrs_.num_ragged; ++d) { config.ragged.emplace_back(ragged_keys_t[d], attrs_.ragged_value_types[d], attrs_.ragged_split_types[d]); diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index 050b83980c6..9b625c256a5 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -31,6 +31,7 @@ limitations under the License. #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) +#include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -400,20 +401,7 @@ class CufftScratchAllocator : public se::ScratchAllocator { int64 GetCufftWorkspaceLimit(const string& envvar_in_mb, int64 default_value_in_bytes) { - const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str()); - if (workspace_limit_in_mb_str != nullptr && - strcmp(workspace_limit_in_mb_str, "") != 0) { - int64 scratch_limit_in_mb = -1; - Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes, - &scratch_limit_in_mb); - if (!status.ok()) { - LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": " - << workspace_limit_in_mb_str; - } else { - return scratch_limit_in_mb * (1 << 20); - } - } - return default_value_in_bytes; + return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes); } class FFTGPUBase : public FFTBase { diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h index 33c5df1ae23..5be032fb84d 100644 --- a/tensorflow/core/kernels/gpu_prim.h +++ b/tensorflow/core/kernels/gpu_prim.h @@ -35,13 +35,15 @@ namespace gpuprim = ::cub; #include "rocm/include/hipcub/hipcub.hpp" namespace gpuprim = ::hipcub; +// Required for sorting Eigen::half namespace rocprim { namespace detail { template <> struct radix_key_codec_base - : radix_key_codec_floating {}; + : radix_key_codec_floating {}; }; // namespace detail }; // namespace rocprim -#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_USE_ROCM #endif // TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_ diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc index 7da1963c676..1a14768f487 100644 --- a/tensorflow/core/kernels/gpu_utils.cc +++ b/tensorflow/core/kernels/gpu_utils.cc @@ -22,6 +22,7 @@ limitations under the License. #include "google/protobuf/any.pb.h" #include "absl/algorithm/container.h" #include "absl/base/call_once.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/logger.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/protobuf/conv_autotuning.pb.h" @@ -282,6 +283,64 @@ Status BestCudnnConvAlgorithm(absl::Span results, return Status::OK(); } +namespace gpu_utils { +int64 GetWorkspaceLimit(const string& envvar_in_mb, + int64 default_value_in_bytes) { + const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str()); + if (workspace_limit_in_mb_str != nullptr && + strcmp(workspace_limit_in_mb_str, "") != 0) { + int64 scratch_limit_in_mb = -1; + if (strings::safe_strto64(workspace_limit_in_mb_str, + &scratch_limit_in_mb)) { + return scratch_limit_in_mb * (1 << 20); + } else { + LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": " + << workspace_limit_in_mb_str; + } + } + return default_value_in_bytes; +} +} // namespace gpu_utils + +GpuScratchAllocator::GpuScratchAllocator(int64 memory_limit, + OpKernelContext* context) + : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} + +se::port::StatusOr> GpuScratchAllocator::AllocateBytes( + int64 byte_size) { + Tensor temporary_memory; + if (byte_size < 0) { + return se::port::Status{se::port::error::INVALID_ARGUMENT, + "Requested negative byte size!"}; + } + if (byte_size > memory_limit_) { + return se::port::Status{ + se::port::error::UNAVAILABLE, + absl::StrCat("Requested memory size (", byte_size, + ") exceeds the max memory limit (", memory_limit_, ").")}; + } + AllocationAttributes allocation_attr; + allocation_attr.retry_on_failure = false; + Status allocation_status(context_->allocate_temp( + DT_UINT8, TensorShape({byte_size}), &temporary_memory, + AllocatorAttributes(), allocation_attr)); + if (!allocation_status.ok()) { + return se::port::Status{ + se::port::error::UNAVAILABLE, + absl::StrCat("Failed to allocate the requested memory size (", + byte_size, ").")}; + } + // Hold the reference of the allocated tensors until the end of the + // allocator. + // NOTE: We expect tensors to be deallocated when this allocator goes out of + // scope when allocated_tensors is destructed. + allocated_tensors_.push_back(temporary_memory); + total_byte_size_ += byte_size; + return se::port::StatusOr>( + AsDeviceMemory(temporary_memory.flat().data(), + temporary_memory.flat().size())); +} + } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index a1589db3b5b..62db406513b 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -243,6 +243,37 @@ void LogFusedConvForwardAutotuneResults( Status BestCudnnConvAlgorithm(absl::Span results, se::dnn::AlgorithmConfig* algo); +namespace gpu_utils { +// Get a workspace limit from the environment variable, which is in MB. +// Return the workspace memory limit in bytes. If no value is set, return the +// default value. +int64 GetWorkspaceLimit(const string& envvar_in_mb, + int64 default_value_in_bytes); +} // namespace gpu_utils + +// A class to provide scratch-space allocator for Stream-Executor callbacks in +// CUDA libraries (CUDNN etc.). +// TensorFlow is responsible for releasing the temporary buffers after +// the kernel finishes. +class GpuScratchAllocator : public se::ScratchAllocator { + public: + virtual ~GpuScratchAllocator() {} + + GpuScratchAllocator(int64 memory_limit, OpKernelContext* context); + + int64 GetMemoryLimitInBytes() override { return memory_limit_; } + + se::port::StatusOr> AllocateBytes( + int64 byte_size) override; + + int64 TotalByteSize() { return total_byte_size_; } + + private: + int64 memory_limit_; + int64 total_byte_size_; + OpKernelContext* context_; + std::vector allocated_tensors_; +}; } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/image/BUILD b/tensorflow/core/kernels/image/BUILD index d60455df8fb..a94e98dc593 100644 --- a/tensorflow/core/kernels/image/BUILD +++ b/tensorflow/core/kernels/image/BUILD @@ -135,7 +135,7 @@ IMAGE_DEPS = [ "//tensorflow/core:lib_internal", "//tensorflow/core/lib/png:png_io", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/kernels:eigen_helpers", "//tensorflow/core/util/tensor_bundle", "//tensorflow/core/util:image_resizer_state", @@ -353,7 +353,10 @@ tf_cuda_cc_test( "resize_bilinear_op_test.cc", "resize_nearest_neighbor_op_test.cc", ], - tags = ["no_cuda_on_cpu_tap"], + tags = [ + "no_cuda_asan", # TODO(b/171334997): re-enable + "no_cuda_on_cpu_tap", + ], deps = [ ":image", ":sampling_kernels", @@ -387,7 +390,10 @@ tf_cuda_cc_test( tf_cuda_cc_test( name = "non_max_suppression_op_gpu_test", srcs = ["non_max_suppression_op_gpu_test.cc"], - tags = tf_cuda_tests_tags() + ["no_cuda_on_cpu_tap"], + tags = tf_cuda_tests_tags() + [ + "no_cuda_asan", # TODO(b/171263349): re-enable. + "no_cuda_on_cpu_tap", + ], deps = [ ":image", "@com_google_absl//absl/strings", diff --git a/tensorflow/core/kernels/image/crop_and_resize_op.cc b/tensorflow/core/kernels/image/crop_and_resize_op.cc index 5c196df9cfe..4efc4ae8846 100644 --- a/tensorflow/core/kernels/image/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/image/crop_and_resize_op.cc @@ -207,7 +207,7 @@ class CropAndResizeOp : public AsyncOpKernel { namespace functor { template struct CropAndResize { - bool operator()(const OpKernelContext* context, + bool operator()(OpKernelContext* context, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, typename TTypes::ConstTensor box_index, @@ -222,6 +222,17 @@ struct CropAndResize { const int crop_width = crops.dimension(2); const int depth = crops.dimension(3); + // Since `functor::CropAndResize` operates on float, we first validate + // that we don't overflow (since overflow causes undefined behavior which + // could result in segfault in this scenario). + const Eigen::Tensor only_finite_elements = + boxes.isfinite().all(); + if (!only_finite_elements()) { + context->SetStatus(errors::InvalidArgument( + "Boxes contains at least one element that is not finite")); + return false; + } + // Sharding across boxes. auto CropAndResizePerBox = [&](int64 start_box, int64 limit_box) { for (int b = start_box; b < limit_box; ++b) { diff --git a/tensorflow/core/kernels/image/decode_image_op.cc b/tensorflow/core/kernels/image/decode_image_op.cc index 7b55ec35750..61b126fb81e 100644 --- a/tensorflow/core/kernels/image/decode_image_op.cc +++ b/tensorflow/core/kernels/image/decode_image_op.cc @@ -242,6 +242,16 @@ class DecodeImageV2Op : public OpKernel { flags.crop_x = crop_window_vec(1); flags.crop_height = crop_window_vec(2); flags.crop_width = crop_window_vec(3); + } else if (op_type_ == "DecodeBmp") { + // TODO(b/171060723): Only DecodeBmp as op_type_ is not acceptable here + // because currently `decode_(jpeg|png|gif)` ops can decode any one of + // jpeg, png or gif but not bmp. Similarly, `decode_bmp` cannot decode + // anything but bmp formats. This behavior needs to be revisited. For more + // details, please refer to the bug. + OP_REQUIRES(context, false, + errors::InvalidArgument( + "Trying to decode JPEG format using DecodeBmp op. Use " + "`decode_jpeg` or `decode_image` instead.")); } // Output tensor and the image buffer size. @@ -346,6 +356,24 @@ class DecodeImageV2Op : public OpKernel { status = context->allocate_output( 0, TensorShape({height, width, decode.channels}), &output); } + + if (op_type_ == "DecodeBmp") { + // TODO(b/171060723): Only DecodeBmp as op_type_ is not acceptable here + // because currently `decode_(jpeg|png|gif)` ops can decode any one of + // jpeg, png or gif but not bmp. Similarly, `decode_bmp` cannot decode + // anything but bmp formats. This behavior needs to be revisited. For more + // details, please refer to the bug. + OP_REQUIRES(context, false, + errors::InvalidArgument( + "Trying to decode PNG format using DecodeBmp op. Use " + "`decode_png` or `decode_image` instead.")); + } else if (op_type_ == "DecodeAndCropJpeg") { + OP_REQUIRES(context, false, + errors::InvalidArgument( + "DecodeAndCropJpeg operation can run on JPEG only, but " + "detected PNG.")); + } + if (!status.ok()) png::CommonFreeDecode(&decode); OP_REQUIRES_OK(context, status); @@ -393,6 +421,23 @@ class DecodeImageV2Op : public OpKernel { errors::InvalidArgument("channels must be 0 or 3 for GIF, got ", channels_)); + if (op_type_ == "DecodeBmp") { + // TODO(b/171060723): Only DecodeBmp as op_type_ is not acceptable here + // because currently `decode_(jpeg|png|gif)` ops can decode any one of + // jpeg, png or gif but not bmp. Similarly, `decode_bmp` cannot decode + // anything but bmp formats. This behavior needs to be revisited. For more + // details, please refer to the bug. + OP_REQUIRES(context, false, + errors::InvalidArgument( + "Trying to decode GIF format using DecodeBmp op. Use " + "`decode_gif` or `decode_image` instead.")); + } else if (op_type_ == "DecodeAndCropJpeg") { + OP_REQUIRES(context, false, + errors::InvalidArgument( + "DecodeAndCropJpeg operation can run on JPEG only, but " + "detected GIF.")); + } + // Decode GIF, allocating tensor if dtype is uint8, otherwise defer tensor // allocation til after dtype conversion is done. `gif`::Decode` supports // uint8 only. @@ -477,6 +522,21 @@ class DecodeImageV2Op : public OpKernel { errors::InvalidArgument( "`channels` must be 0, 3 or 4 for BMP, but got ", channels_)); + if (op_type_ != "DecodeBmp" && op_type_ != "DecodeImage") { + if (op_type_ == "DecodeAndCropJpeg") { + OP_REQUIRES(context, false, + errors::InvalidArgument( + "DecodeAndCropJpeg operation can run on JPEG only, but " + "detected BMP.")); + } else { + OP_REQUIRES(context, false, + errors::InvalidArgument( + "Trying to decode BMP format using a wrong op. Use " + "`decode_bmp` or `decode_image` instead. Op used: ", + op_type_)); + } + } + OP_REQUIRES(context, (32 <= input.size()), errors::InvalidArgument("Incomplete bmp content, requires at " "least 32 bytes to find the header " diff --git a/tensorflow/core/kernels/in_topk_op_test.cc b/tensorflow/core/kernels/in_topk_op_test.cc index aacecb08bbe..9e4da735c5a 100644 --- a/tensorflow/core/kernels/in_topk_op_test.cc +++ b/tensorflow/core/kernels/in_topk_op_test.cc @@ -76,9 +76,9 @@ static Graph* InTopK(int num_targets, int num_classes, T top_k) { BM_InTopK(int64, 64, 1000, 10, cpu); BM_InTopK(int64, 64, 10000, 10, cpu); -#ifdef GOOGLE_CUDA +#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) BM_InTopK(int64, 64, 1000, 10, gpu); BM_InTopK(int64, 64, 10000, 10, gpu); -#endif // GOOGLE_CUDA +#endif // defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) } // namespace tensorflow diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index b9b2d1f0eae..1fe9a34e67d 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -549,7 +549,7 @@ struct EinsumHelper { static Status ContractOperands(OpKernelContext* ctx, absl::Span inputs, absl::Span swap_free_and_contract, - Tensor* output) { + bool use_autotune, Tensor* output) { if (inputs.size() == 1) return CopyFrom(inputs[0], inputs[0].shape(), output); MatMulBCast bcast(inputs[0].shape().dim_sizes(), @@ -583,7 +583,7 @@ struct EinsumHelper { ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); LaunchBatchMatMul::Launch(ctx, lhs, rhs, /*adj_x=*/false, /*adj_y=*/false, trans_x, trans_y, - bcast, &output_reshaped); + bcast, use_autotune, &output_reshaped); return Status::OK(); } }; @@ -598,6 +598,7 @@ class EinsumOp : public OpKernel { equation_, &input_labels_, &output_labels_, &label_types_, &input_label_counts_, &output_label_counts_, &input_has_ellipsis_, &output_has_ellipsis_)); + use_autotune_ = MatmulAutotuneEnable(); } void Compute(OpKernelContext* ctx) override { @@ -640,7 +641,7 @@ class EinsumOp : public OpKernel { Tensor contraction_output_reshaped; OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands( ctx, inputs_reduced, swap_free_and_contract, - &contraction_output_reshaped)); + use_autotune_, &contraction_output_reshaped)); // Copy the batch labels from the contraction output. Recover the batch // shape, which may have been broadcasted. @@ -738,6 +739,7 @@ class EinsumOp : public OpKernel { LabelCounts output_label_counts_; gtl::InlinedVector input_has_ellipsis_; bool output_has_ellipsis_ = false; + bool use_autotune_; }; #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/linalg/svd_op_gpu.cu.cc b/tensorflow/core/kernels/linalg/svd_op_gpu.cu.cc index 06d1efe6dd5..4fc41940e60 100644 --- a/tensorflow/core/kernels/linalg/svd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/linalg/svd_op_gpu.cu.cc @@ -343,9 +343,6 @@ class SvdOpGpu : public AsyncOpKernel { void ComputeAsync(OpKernelContext* context, DoneCallback done) final { const Tensor& input = context->input(0); const int ndims = input.dims(); - const int64 m = input.dim_size(ndims - 2); - const int64 n = input.dim_size(ndims - 1); - const int64 p = std::min(m, n); // Validate inputs. OP_REQUIRES_ASYNC( @@ -353,6 +350,10 @@ class SvdOpGpu : public AsyncOpKernel { errors::InvalidArgument("Input must have rank >= 2, got ", ndims), done); + const int64 m = input.dim_size(ndims - 2); + const int64 n = input.dim_size(ndims - 1); + const int64 p = std::min(m, n); + // output tensors. Tensor* outputU = NULL; Tensor* outputS = NULL; diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD index 6689e0fb200..297173cfca3 100644 --- a/tensorflow/core/kernels/mkl/BUILD +++ b/tensorflow/core/kernels/mkl/BUILD @@ -22,7 +22,7 @@ MKL_SHORT_DEPS = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/kernels:ops_util", ] + mkl_deps() diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 5814eec0b76..b579f34b17a 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -941,23 +941,15 @@ class MklConvOp : public OpKernel { const Tensor& add_tensor = MklGetInput(context, kInputIndex_Add); MklDnnShape add_mkl_shape; GetMklShape(context, kInputIndex_Add, &add_mkl_shape, native_format); - if (native_format) { - // Forward the summand tensor to the output only if it has no other - // references, otherwise make a copy of it. - if (!context->forward_input_to_output_with_shape( - kInputIndex_Add, kOutputIndex_Dst, output_tf_shape, - output_tensor)) { - AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, - output_tf_shape, *output_mkl_shape, - native_format); - bool result = - (*output_tensor)->CopyFrom(add_tensor, add_tensor.shape()); - DCHECK(result); - } + // Forward the summand tensor to the output only if it has no other + // references, otherwise make a copy of it. + if (native_format && context->forward_input_to_output_with_shape( + kInputIndex_Add, kOutputIndex_Dst, + output_tf_shape, output_tensor)) { return; } // Check if reorder is needed - if (add_mkl_shape == *output_mkl_shape && + if (!native_format && add_mkl_shape == *output_mkl_shape && ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add, kOutputIndex_Dst, output_tensor, add_mkl_shape, false)) { @@ -987,6 +979,13 @@ class MklConvOp : public OpKernel { const_cast(add_tensor.flat().data())); void* dst_buf = static_cast((*output_tensor)->flat().data()); + if (native_format) { + // We are simply deep copying the add_tensor to output_tensor without + // changing memory layout, hence using same memory descriptor. + ADD_MD = DST_MD = + memory::desc({add_tensor.NumElements()}, MklDnnType(), + mkldnn::memory::format_tag::x); + } fuse_add_src_.reset( new MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf)); fuse_add_dst_.reset( diff --git a/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h b/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h index 9fd699cf704..1624a00331a 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h +++ b/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h @@ -74,7 +74,7 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a, #pragma omp parallel for #endif // !ENABLE_MKLDNN_THREADPOOL // TODO: Add eigen parallel_for - for (size_t n = 0; n < n_channel; ++n) { + for (int64_t n = 0; n < n_channel; ++n) { float a_float_for_one_quant_level = MklFloatForOneQuantizedLevel(min_a, max_a); float b_float_for_one_quant_level = diff --git a/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc b/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc index 0cd4843c0d8..f6bc773de4f 100644 --- a/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc @@ -77,10 +77,14 @@ class MklRequantizationRangePerChannelOp : public OpKernel { float out_min_max = std::numeric_limits::min(); #ifndef ENABLE_MKLDNN_THREADPOOL +#ifdef _MSC_VER +#pragma omp parallel for +#else #pragma omp parallel for reduction(max : out_min_max) +#endif #endif // !ENABLE_MKLDNN_THREADPOOL // TODO: Add eigen parallel_for - for (size_t i = 0; i < depth; ++i) { + for (int64_t i = 0; i < depth; ++i) { Eigen::Tensor min = transposed_input.chip<0>(i).minimum(); Eigen::Tensor max = diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index 1d18e41d3ef..062f87e745f 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -92,7 +92,9 @@ tf_cuda_cc_test( name = "gpu_tanh_test", size = "small", srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_tanh_test.cc"]), - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + [ + "no_cuda_asan", # TODO(b/171341759): re-enable. + ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -109,7 +111,9 @@ tf_cuda_cc_test( name = "gpu_abs_test", size = "small", srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_abs_test.cc"]), - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + [ + "no_cuda_asan", # TODO(b/171341759): re-enable. + ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h index 457658948ed..5d34ac83fb5 100644 --- a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h @@ -97,13 +97,18 @@ Tensor ConvertDescriptorToTensor( auto result_desc = MLIR_FUNCTION(tf_op, mlir_type)(ctx, &input_desc); \ free(input_desc.descriptor); \ \ - tensorflow::AllocatorAttributes attrs; \ - auto* allocator = ctx->get_allocator(attrs); \ - \ - Tensor result_tensor = ConvertDescriptorToTensor( \ - result_desc, tf_data_type, allocator); \ + /* Compare data pointers to detect forwarding. */ \ + void* result_data_ptr = static_cast(result_desc.descriptor)[0]; \ + if (input.data() == result_data_ptr) { \ + ctx->set_output(0, input); \ + } else { \ + tensorflow::AllocatorAttributes attrs; \ + auto* allocator = ctx->get_allocator(attrs); \ + Tensor result_tensor = ConvertDescriptorToTensor( \ + result_desc, tf_data_type, allocator); \ + ctx->set_output(0, result_tensor); \ + } \ free(result_desc.descriptor); \ - ctx->set_output(0, result_tensor); \ } \ }; \ } \ diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.cc b/tensorflow/core/kernels/quantize_and_dequantize_op.cc index dec0262cf04..675bdaec225 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op.cc @@ -71,6 +71,10 @@ class QuantizeAndDequantizeV2Op : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& input = ctx->input(0); + OP_REQUIRES( + ctx, (axis_ == -1 || axis_ < input.shape().dims()), + errors::InvalidArgument("Shape must be at least rank ", axis_ + 1, + " but is rank ", input.shape().dims())); const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_); Tensor input_min_tensor; Tensor input_max_tensor; diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc index aa736ad7f60..d9993bb6d39 100644 --- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc @@ -20,110 +20,76 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace { -struct RaggedTensor { - Tensor values; - std::vector nested_splits; -}; - -Status RaggedComponentsFromVariant(const Tensor& encoded_variant, - int ragged_rank, DataType value_dtype, - DataType split_dtype, - std::vector* decoded_ragged) { +Status RaggedComponentsFromVariant( + const Tensor& encoded_variant, int ragged_rank, DataType value_dtype, + DataType split_dtype, std::vector* decoded_ragged) { const auto& flat_variants = encoded_variant.flat(); - decoded_ragged->resize(flat_variants.size()); - // Step 1: Extract the 1-D DT_VARIANT Tensor from each Variant element in the - // input. + decoded_ragged->reserve(flat_variants.size()); + for (int i = 0; i < flat_variants.size(); i++) { const auto& flat_variant = flat_variants(i); - const Tensor* encoded_list = flat_variant.get(); - if (encoded_list == nullptr) { + const RaggedTensorVariant* decoded = + flat_variant.get(); + if (decoded == nullptr) { return errors::InvalidArgument( "Input Variant element at index ", i, - " doesn't hold a Tensor: ", flat_variant.DebugString()); + " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString()); } - if (encoded_list->dims() != 1) { + decoded_ragged->push_back(*decoded); + decoded = &decoded_ragged->back(); + // Check ragged rank & types + if (decoded->ragged_rank() != ragged_rank) { return errors::InvalidArgument( - "Encoded input Variant must have rank 1, but found rank: ", - encoded_list->dims(), - ". encoded input Variant: ", encoded_list->DebugString()); + "Encoded input RaggedTensorVariant has ragged_rank=", + decoded->ragged_rank(), ". Expected ragged_rank=", ragged_rank, "."); } - if (encoded_list->NumElements() != (ragged_rank + 1) && - encoded_list->NumElements() != 1) { - return errors::InvalidArgument( - "Encoded input Variant must hold either input_ragged_rank + 1 " - "Tensors or an empty Tensor (zero splits Tensors, 1 values Tensor), " - "input_ragged_rank: ", - ragged_rank, - ", encoded input Variant: ", encoded_list->DebugString()); - } - const auto& input_vec = encoded_list->vec(); - - // Step 2: Get the splits and value Tensors from the 1-D DT_VARIANT Tensor - // to create the component RaggedTensors. - (*decoded_ragged)[i].nested_splits.reserve(ragged_rank); - for (int j = 0; j < ragged_rank; j++) { - const Tensor* split_tensor = input_vec(j).get(); - if (split_tensor == nullptr) { - return errors::InvalidArgument( - "Encoded scalar element at index ", i, - " doesn't have a splits Tensor at split_index ", j, ": ", - input_vec(j).DebugString()); - } - Tensor splits_tensor = *split_tensor; - if (splits_tensor.dtype() != split_dtype) { - return errors::InvalidArgument( - "Expected splits Tensor dtype: ", split_dtype, - ", found: ", splits_tensor.dtype()); - } - if (splits_tensor.dims() != 1) { - return errors::InvalidArgument( - "Ragged splits must have rank 1; encoded scalar element at index ", - i, " has splits Tensor at split_index ", j, ": ", - splits_tensor.DebugString()); - } - (*decoded_ragged)[i].nested_splits.push_back(splits_tensor); - } - const Tensor* values_tensor = input_vec(ragged_rank).get(); - if (values_tensor == nullptr) { - return errors::InvalidArgument("Encoded scalar element at index ", i, - " doesn't have a values Tensor: ", - input_vec(ragged_rank).DebugString()); - } - if (values_tensor->dtype() != value_dtype) { + if (decoded->values().dtype() != value_dtype) { return errors::InvalidArgument( "Expected values Tensor dtype: ", DataTypeString(value_dtype), - ", found: ", DataTypeString(values_tensor->dtype())); + ", found: ", DataTypeString(decoded->values().dtype())); } - if (values_tensor->dims() < 1) { + if (decoded->values().dims() < 1) { return errors::InvalidArgument( "Ragged values must have rank >= 1; encoded scalar element at index ", - i, " has values Tensor: ", values_tensor->DebugString()); + i, " has values Tensor: ", decoded->values().DebugString()); + } + for (const auto& splits : decoded->nested_splits()) { + if (splits.dtype() != split_dtype) { + return errors::InvalidArgument( + "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype), + ", found: ", DataTypeString(splits.dtype())); + } + if (splits.dims() != 1) { + return errors::InvalidArgument( + "Ragged splits must have rank 1; encoded scalar element at index ", + i, " has splits Tensor ", splits.DebugString()); + } } - (*decoded_ragged)[i].values = *values_tensor; } return Status::OK(); } template Status NestedStackRaggedTensors( - const std::vector& ragged_components, + const std::vector& ragged_components, const std::vector& nested_dim_sizes, const int input_ragged_rank, - const int output_ragged_rank, RaggedTensor* output_ragged) { - output_ragged->nested_splits.reserve(output_ragged_rank); + const int output_ragged_rank, RaggedTensorVariant* output_ragged) { + output_ragged->mutable_nested_splits()->reserve(output_ragged_rank); const int dims = nested_dim_sizes.size(); // Populate first `dims - 1` splits. for (int i = 0; i < dims - 1; i++) { int dims_splits_size = nested_dim_sizes[i] + 1; - output_ragged->nested_splits.push_back(Tensor( - DataTypeToEnum::value, TensorShape({dims_splits_size}))); - auto splits_vec = output_ragged->nested_splits[i].vec(); + output_ragged->append_splits(Tensor(DataTypeToEnum::value, + TensorShape({dims_splits_size}))); + auto splits_vec = output_ragged->mutable_splits(i)->vec(); int split_diff = nested_dim_sizes[i + 1]; for (int j = 0; j < dims_splits_size; j++) { splits_vec(j) = j * split_diff; @@ -132,15 +98,15 @@ Status NestedStackRaggedTensors( // Populate `dims`-th split. int splits_size = ragged_components.size() + 1; - output_ragged->nested_splits.push_back( + output_ragged->append_splits( Tensor(DataTypeToEnum::value, TensorShape({splits_size}))); auto dims_splits_vec = - output_ragged->nested_splits[dims - 1].vec(); + output_ragged->mutable_splits(dims - 1)->vec(); dims_splits_vec(0) = 0; for (int i = 0; i < ragged_components.size(); i++) { - int split_val = ragged_components[i].values.shape().dim_size(0); - if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) { - split_val = ragged_components[i].nested_splits[0].NumElements() - 1; + int split_val = ragged_components[i].values().shape().dim_size(0); + if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) { + split_val = ragged_components[i].splits(0).NumElements() - 1; } dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val; } @@ -150,24 +116,24 @@ Status NestedStackRaggedTensors( int split_index = dims + i; int split_size = 1; for (int j = 0; j < ragged_components.size(); j++) { - if (!ragged_components[j].nested_splits.empty()) { - split_size += ragged_components[j].nested_splits[i].NumElements() - 1; + if (!ragged_components[j].nested_splits().empty()) { + split_size += ragged_components[j].splits(i).NumElements() - 1; } } - output_ragged->nested_splits.push_back( + output_ragged->append_splits( Tensor(DataTypeToEnum::value, TensorShape({split_size}))); auto splits_vec = - output_ragged->nested_splits[split_index].vec(); + output_ragged->mutable_splits(split_index)->vec(); splits_vec(0) = 0; SPLIT_TYPE last_split_value = 0; int index = 1; for (int j = 0; j < ragged_components.size(); j++) { - if (ragged_components[j].nested_splits.empty()) { + if (ragged_components[j].nested_splits().empty()) { // Corner case: empty row. e.g [ [[x], [x]], [] ] continue; } auto component_splits_vec = - ragged_components[j].nested_splits[i].vec(); + ragged_components[j].splits(i).vec(); for (int k = 1; k < component_splits_vec.size(); k++, index++) { splits_vec(index) = component_splits_vec(k) + last_split_value; } @@ -187,35 +153,35 @@ Status NestedStackRaggedTensors( if (ragged_components.empty()) { component_values_shape = TensorShape({0}); } else { - component_values_shape = ragged_components[0].values.shape(); + component_values_shape = ragged_components[0].values().shape(); } // Populate values. int values_size = component_values_shape.dim_size(0); for (int i = 1; i < ragged_components.size(); i++) { - if (ragged_components[i].values.dims() != component_values_shape.dims()) { + if (ragged_components[i].values().dims() != component_values_shape.dims()) { return errors::InvalidArgument( "Rank of values must match for all " "components; values shape at index 0: ", component_values_shape.DebugString(), ", values shape at index ", i, - ": ", ragged_components[i].values.shape().DebugString()); + ": ", ragged_components[i].values().shape().DebugString()); } - values_size += ragged_components[i].values.shape().dim_size(0); + values_size += ragged_components[i].values().shape().dim_size(0); } component_values_shape.set_dim(0, values_size); - output_ragged->values = - Tensor(DataTypeToEnum::value, component_values_shape); + output_ragged->set_values( + Tensor(DataTypeToEnum::value, component_values_shape)); auto output_values_flat = - output_ragged->values.flat_outer_dims(); + output_ragged->mutable_values()->flat_outer_dims(); int values_index = 0; for (int i = 0; i < ragged_components.size(); i++) { auto component_values_flat = - ragged_components[i].values.flat_outer_dims(); - int num_inner_elements = ragged_components[i].values.NumElements(); - if (ragged_components[i].values.dim_size(0) > 0) { - num_inner_elements /= ragged_components[i].values.dim_size(0); + ragged_components[i].values().flat_outer_dims(); + int num_inner_elements = ragged_components[i].values().NumElements(); + if (ragged_components[i].values().dim_size(0) > 0) { + num_inner_elements /= ragged_components[i].values().dim_size(0); } - for (int j = 0; j < ragged_components[i].values.dim_size(0); + for (int j = 0; j < ragged_components[i].values().dim_size(0); j++, values_index++) { for (int k = 0; k < num_inner_elements; k++) { output_values_flat(values_index, k) = component_values_flat(j, k); @@ -265,7 +231,7 @@ class RaggedTensorFromVariantOp : public OpKernel { // Decode all variants. const auto value_dtype = DataTypeToEnum::v(); const auto split_dtype = DataTypeToEnum::v(); - std::vector decoded_components; + std::vector decoded_components; OP_REQUIRES_OK(context, RaggedComponentsFromVariant( encoded_variant, input_ragged_rank_, value_dtype, split_dtype, &decoded_components)); @@ -281,7 +247,7 @@ class RaggedTensorFromVariantOp : public OpKernel { for (int i = 0; i < encoded_variant.dims(); i++) { encoded_dim_sizes[i] = encoded_variant.dim_size(i); } - RaggedTensor output_ragged; + RaggedTensorVariant output_ragged; OP_REQUIRES_OK( context, NestedStackRaggedTensors( decoded_components, encoded_dim_sizes, input_ragged_rank_, @@ -296,15 +262,15 @@ class RaggedTensorFromVariantOp : public OpKernel { int output_ragged_rank_; void ReturnRaggedTensor(OpKernelContext* context, - RaggedTensor ragged_tensor) { - int ragged_rank = ragged_tensor.nested_splits.size(); + const RaggedTensorVariant& ragged_tensor) { + int ragged_rank = ragged_tensor.ragged_rank(); OpOutputList splits_out; OP_REQUIRES_OK(context, context->output_list("output_nested_splits", &splits_out)); for (int i = 0; i < ragged_rank; i++) { - splits_out.set(i, ragged_tensor.nested_splits[i]); + splits_out.set(i, ragged_tensor.splits(i)); } - context->set_output(ragged_rank, ragged_tensor.values); + context->set_output(ragged_rank, ragged_tensor.values()); } }; diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc index bdf321d0515..fc46283c90e 100644 --- a/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc +++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -55,28 +56,22 @@ class RaggedTensorFromVariantKernelTest : public ::tensorflow::OpsTestBase { } template - Tensor CreateVariantFromRagged( + RaggedTensorVariant CreateVariantFromRagged( const std::vector>& ragged_splits, const TensorShape& ragged_values_shape, const std::vector& ragged_values) { - // Step 1: Create Tensors out of ragged splits and values. - std::vector ragged_components; + RaggedTensorVariant encoded; for (auto ragged_split : ragged_splits) { int splits_size = ragged_split.size(); Tensor splits(DataTypeToEnum::v(), TensorShape({splits_size})); test::FillValues(&splits, ragged_split); - ragged_components.push_back(splits); + encoded.append_splits(splits); } Tensor values(DataTypeToEnum::v(), ragged_values_shape); test::FillValues(&values, ragged_values); - ragged_components.push_back(values); - - // Step 2: Encode into a 1-D Variant Tensor. - int num_splits = ragged_splits.size(); - Tensor encoded_list(DT_VARIANT, TensorShape({num_splits + 1})); - test::FillValues(&encoded_list, ragged_components); - return encoded_list; + encoded.set_values(values); + return encoded; } }; @@ -85,7 +80,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, ScalarInput) { const std::vector split_2 = {0, 1, 2, 5, 6, 7}; const std::vector values = {0, 1, 1, 2, 2, 3, 4}; - Tensor encoded_variant = CreateVariantFromRagged( + auto encoded_variant = CreateVariantFromRagged( {split_1, split_2}, TensorShape({7}), values); Tensor expected_splits_1(DT_INT64, TensorShape({6})); Tensor expected_splits_2(DT_INT64, TensorShape({6})); @@ -113,7 +108,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, OneInputElement) { const std::vector values = {0, 1, 1, 2, 2, 3, 4}; const std::vector batched_splits_1 = {0, 5}; - Tensor encoded_variant = CreateVariantFromRagged( + auto encoded_variant = CreateVariantFromRagged( {split_1, split_2}, TensorShape({7}), values); Tensor expected_splits_1(DT_INT64, TensorShape({2})); Tensor expected_splits_2(DT_INT64, TensorShape({6})); @@ -157,13 +152,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, TensorIn2DOut) { const std::vector batched_splits_2 = {0, 3, 3, 5, 6}; const std::vector batched_values = {1, 2, 3, 4, 5, 6}; - Tensor component_variant_1 = + auto component_variant_1 = CreateVariantFromRagged({}, TensorShape({3}), values_1); - Tensor component_variant_2 = + auto component_variant_2 = CreateVariantFromRagged({}, TensorShape({0}), values_2); - Tensor component_variant_3 = + auto component_variant_3 = CreateVariantFromRagged({}, TensorShape({2}), values_3); - Tensor component_variant_4 = + auto component_variant_4 = CreateVariantFromRagged({}, TensorShape({1}), values_4); Tensor expected_splits_1(DT_INT64, TensorShape({3})); @@ -223,15 +218,15 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOut) { test::FillValues(&expected_splits_3, batched_splits_3); test::FillValues(&expected_values, batched_values); - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {component_split_1_1}, TensorShape({1}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged( + auto variant_component_2 = CreateVariantFromRagged( {component_split_2_1}, TensorShape({2}), component_values_2); - Tensor variant_component_3 = CreateVariantFromRagged( + auto variant_component_3 = CreateVariantFromRagged( {component_split_3_1}, TensorShape({2}), component_values_3); - Tensor variant_component_4 = CreateVariantFromRagged( + auto variant_component_4 = CreateVariantFromRagged( {component_split_4_1}, TensorShape({3}), component_values_4); - Tensor variant_component_5 = CreateVariantFromRagged( + auto variant_component_5 = CreateVariantFromRagged( {component_split_5_1}, TensorShape({3}), component_values_5); int input_ragged_rank = 1; int output_ragged_rank = 3; @@ -297,10 +292,10 @@ TEST_F(RaggedTensorFromVariantKernelTest, test::FillValues(&expected_splits_4, batched_splits_4); test::FillValues(&expected_values, batched_values); - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {component_split_1_1, component_split_1_2}, TensorShape({11}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged( + auto variant_component_2 = CreateVariantFromRagged( {component_split_2_1, component_split_2_2}, TensorShape({11}), component_values_2); int input_ragged_rank = -1; @@ -336,9 +331,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, EmptyRow1DIn2DOut) { test::FillValues(&expected_splits_2, batched_splits_2); test::FillValues(&expected_values, batched_values); - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {component_split_1_1}, TensorShape({3}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged( + auto variant_component_2 = CreateVariantFromRagged( {component_split_2_1}, TensorShape({0}), {}); // Empty row. int input_ragged_rank = 1; int output_ragged_rank = 2; @@ -371,9 +366,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, NDValues1DIn2DOut) { test::FillValues(&expected_splits_2, batched_splits_2); test::FillValues(&expected_values, batched_values); - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {component_split_1_1}, TensorShape({1, 2}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged( + auto variant_component_2 = CreateVariantFromRagged( {component_split_2_1}, TensorShape({2, 2}), component_values_2); int input_ragged_rank = 1; int output_ragged_rank = 2; @@ -423,15 +418,15 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) { test::FillValues(&expected_splits_3, batched_splits_3); test::FillValues(&expected_values, batched_values); - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {component_split_1_1}, TensorShape({1}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged( + auto variant_component_2 = CreateVariantFromRagged( {component_split_2_1}, TensorShape({2}), component_values_2); - Tensor variant_component_3 = CreateVariantFromRagged( + auto variant_component_3 = CreateVariantFromRagged( {component_split_3_1}, TensorShape({2}), component_values_3); - Tensor variant_component_4 = CreateVariantFromRagged( + auto variant_component_4 = CreateVariantFromRagged( {component_split_4_1}, TensorShape({3}), component_values_4); - Tensor variant_component_5 = CreateVariantFromRagged( + auto variant_component_5 = CreateVariantFromRagged( {component_split_5_1}, TensorShape({3}), component_values_5); int input_ragged_rank = 1; int output_ragged_rank = 3; @@ -451,13 +446,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) { // Tests for invalid inputs. TEST_F(RaggedTensorFromVariantKernelTest, InvalidInferredInputRaggedRank) { - Tensor component_variant_1 = + auto component_variant_1 = CreateVariantFromRagged({}, TensorShape({3}), {1, 2, 3}); - Tensor component_variant_2 = + auto component_variant_2 = CreateVariantFromRagged({}, TensorShape({0}), {}); - Tensor component_variant_3 = + auto component_variant_3 = CreateVariantFromRagged({}, TensorShape({2}), {1, 2}); - Tensor component_variant_4 = + auto component_variant_4 = CreateVariantFromRagged({}, TensorShape({1}), {1}); int input_ragged_rank = -1; @@ -478,9 +473,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) { const std::vector component_values_1 = {0}; const std::vector component_values_2 = {0, 1}; - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {component_split_1_1}, TensorShape({1}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged( + auto variant_component_2 = CreateVariantFromRagged( {component_split_2_1}, TensorShape({2}), component_values_2); int input_ragged_rank = 1; @@ -493,33 +488,21 @@ TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) { "input_ragged_rank + encoded_ragged.dims()")); } -TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldTensors) { +TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldRaggedTensorVariant) { int input_ragged_rank = 1; int output_ragged_rank = 2; BuildDecodeRaggedTensorGraph( input_ragged_rank, output_ragged_rank, TensorShape({2}), {1, 2}); EXPECT_TRUE(absl::StartsWith( RunOpKernel().error_message(), - "Input Variant element at index 0 doesn't hold a Tensor")); -} - -TEST_F(RaggedTensorFromVariantKernelTest, InputVariantTensorRankNotOne) { - Tensor variant_list(DT_VARIANT, TensorShape({2, 1})); - test::FillValues(&variant_list, {1, 2}); - int input_ragged_rank = 1; - int output_ragged_rank = 2; - BuildDecodeRaggedTensorGraph( - input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list}); - EXPECT_TRUE(absl::StartsWith( - RunOpKernel().error_message(), - "Encoded input Variant must have rank 1, but found rank: 2")); + "Input Variant element at index 0 doesn't hold a RaggedTensorVariant")); } TEST_F(RaggedTensorFromVariantKernelTest, InputScalarElementDoesNotMatchInputRaggedRank) { const std::vector component_split_1_1 = {0, 1}; const std::vector component_values_1 = {1, 2}; - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {component_split_1_1}, TensorShape({1, 2}), component_values_1); int input_ragged_rank = 2; @@ -527,31 +510,17 @@ TEST_F(RaggedTensorFromVariantKernelTest, BuildDecodeRaggedTensorGraph(input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_component_1}); - EXPECT_TRUE(absl::StartsWith( - RunOpKernel().error_message(), - "Encoded input Variant must hold either input_ragged_rank + 1 " - "Tensors or an empty Tensor")); -} - -TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitNotATensor) { - Tensor variant_list(DT_VARIANT, TensorShape({2})); - test::FillValues(&variant_list, {1, 2}); - - int input_ragged_rank = 1; - int output_ragged_rank = 2; - BuildDecodeRaggedTensorGraph(input_ragged_rank, output_ragged_rank, - TensorShape({1}), {variant_list}); EXPECT_TRUE( absl::StartsWith(RunOpKernel().error_message(), - "Encoded scalar element at index 0 doesn't have a " - "splits Tensor at split_index 0")); + "Encoded input RaggedTensorVariant has ragged_rank=1. " + "Expected ragged_rank=2.")); } TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) { const std::vector component_split_1_1 = {0, 1}; const std::vector component_values_1 = {0}; - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {component_split_1_1}, TensorShape({1}), component_values_1); int input_ragged_rank = 1; @@ -559,46 +528,29 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) { BuildDecodeRaggedTensorGraph(input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_component_1}); - EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(), - "Expected splits Tensor dtype: 3, found: 9")); + EXPECT_TRUE(absl::StartsWith( + RunOpKernel().error_message(), + "Expected row_splits Tensor dtype: int32, found: int64")); } TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitRankNotOne) { - Tensor splits(DT_INT64, TensorShape({2, 1})); - test::FillValues(&splits, {1, 2}); - Tensor values(DT_INT32, {2}); - test::FillValues(&values, {1, 2}); - Tensor encoded_list(DT_VARIANT, TensorShape({2})); - test::FillValues(&encoded_list, {splits, values}); + RaggedTensorVariant encoded(Tensor(DT_INT32, {2}), + {Tensor(DT_INT64, {2, 1})}); + test::FillValues(encoded.mutable_splits(0), {1, 2}); + test::FillValues(encoded.mutable_values(), {1, 2}); int input_ragged_rank = 1; int output_ragged_rank = 2; BuildDecodeRaggedTensorGraph( - input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded_list}); + input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded}); EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(), "Ragged splits must have rank 1")); } -TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesNotATensor) { - Tensor splits(DT_INT64, TensorShape({3})); - test::FillValues(&splits, {0, 2, 3}); - Tensor variant_list(DT_VARIANT, TensorShape({2})); - test::FillValues(&variant_list, {splits, 2}); - - int input_ragged_rank = 1; - int output_ragged_rank = 2; - BuildDecodeRaggedTensorGraph( - input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list}); - EXPECT_TRUE( - absl::StartsWith(RunOpKernel().error_message(), - "Encoded scalar element at index 0 doesn't have a " - "values Tensor")); -} - TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) { const std::vector component_split_1_1 = {0, 1}; const std::vector component_values_1 = {0}; - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {component_split_1_1}, TensorShape({1}), component_values_1); int input_ragged_rank = 1; int output_ragged_rank = 2; @@ -611,7 +563,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) { } TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankNotGreaterThanOne) { - Tensor variant_component_1 = + auto variant_component_1 = CreateVariantFromRagged({{0, 1}}, TensorShape({}), {1}); int input_ragged_rank = 1; int output_ragged_rank = 2; @@ -628,9 +580,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankMismatch) { const std::vector component_values_1 = {0}; const std::vector component_values_2 = {0, 1, 2, 3}; - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {component_split_1_1}, TensorShape({1}), component_values_1); - Tensor variant_component_2 = CreateVariantFromRagged( + auto variant_component_2 = CreateVariantFromRagged( {component_split_2_1}, TensorShape({2, 2}), component_values_2); int input_ragged_rank = 1; int output_ragged_rank = 2; @@ -711,13 +663,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, 2DValuesTensorIn1DOut) { const std::vector batched_values = {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5}; - Tensor variant_component_1 = CreateVariantFromRagged( + auto variant_component_1 = CreateVariantFromRagged( {}, TensorShape({2, 2, 2}), {1, 1, 1, 1, 2, 2, 2, 2}); - Tensor variant_component_2 = CreateVariantFromRagged( + auto variant_component_2 = CreateVariantFromRagged( {}, TensorShape({1, 2, 2}), {3, 3, 3, 3}); - Tensor variant_component_3 = + auto variant_component_3 = CreateVariantFromRagged({}, TensorShape({0, 2, 2}), {}); - Tensor variant_component_4 = CreateVariantFromRagged( + auto variant_component_4 = CreateVariantFromRagged( {}, TensorShape({2, 2, 2}), {4, 4, 4, 4, 5, 5, 5, 5}); Tensor expected_splits_1(DT_INT64, TensorShape({5})); diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc index 64c372b005e..549dc68dfbf 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc @@ -18,50 +18,38 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/tensor_ops_util.h" namespace tensorflow { namespace { -struct RaggedTensor { - Tensor values; - std::vector nested_splits; -}; - -Status RaggedToVariant(const RaggedTensor& ragged, Tensor* encoded_list) { - // Encode as a rank-1 Variant Tensor. - int ragged_rank = ragged.nested_splits.size(); - *encoded_list = Tensor(DT_VARIANT, TensorShape({ragged_rank + 1})); - auto encoded_vec = encoded_list->vec(); - for (int i = 0; i < ragged_rank; i++) { - encoded_vec(i) = ragged.nested_splits[i]; - } - encoded_vec(ragged_rank) = ragged.values; - return Status::OK(); -} - template -Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, - std::vector* ragged_components) { +Status UnbatchRaggedZerothDim( + const RaggedTensorVariant& batched_ragged, + std::vector* ragged_components) { // Set up the component Ragged Tensors. - int ragged_rank = batched_ragged.nested_splits.size(); - auto batched_splits_top_vec = - batched_ragged.nested_splits[0].vec(); + int ragged_rank = batched_ragged.ragged_rank(); + auto batched_splits_top_vec = batched_ragged.splits(0).vec(); int num_components = batched_splits_top_vec.size() - 1; int num_splits = ragged_rank - 1; ragged_components->resize(num_components); - for (RaggedTensor ragged_component : *ragged_components) { - ragged_component.nested_splits.reserve(num_splits); + for (RaggedTensorVariant& ragged_component : *ragged_components) { + ragged_component.mutable_nested_splits()->reserve(num_splits); } - const auto& batched_flat = batched_ragged.values.flat(); - int num_inner_elems = batched_ragged.values.NumElements(); - if (batched_ragged.values.dim_size(0) > 1) { - num_inner_elems /= batched_ragged.values.dim_size(0); + const auto& batched_flat = batched_ragged.values().flat(); + int num_inner_elems = batched_ragged.values().NumElements(); + if (batched_ragged.values().dim_size(0) > 1) { + num_inner_elems /= batched_ragged.values().dim_size(0); } - TensorShape values_shape = batched_ragged.values.shape(); + TensorShape values_shape = batched_ragged.values().shape(); // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]] if (num_splits == 0) { @@ -70,10 +58,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, int limit = batched_splits_top_vec(i + 1); int num_values = limit - start; values_shape.set_dim(0, num_values); - (*ragged_components)[i].values = - Tensor(DataTypeToEnum::value, values_shape); + (*ragged_components)[i].set_values( + Tensor(DataTypeToEnum::value, values_shape)); auto ragged_component_values_flat = - (*ragged_components)[i].values.flat(); + (*ragged_components)[i].mutable_values()->flat(); for (int j = 0; j < num_values * num_inner_elems; j++) { ragged_component_values_flat(j) = batched_flat(j + start * num_inner_elems); @@ -86,8 +74,7 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, std::vector::ConstVec> batched_splits_vec; batched_splits_vec.reserve(ragged_rank); for (int i = 0; i < ragged_rank; i++) { - batched_splits_vec.push_back( - batched_ragged.nested_splits[i].vec()); + batched_splits_vec.push_back(batched_ragged.splits(i).vec()); } std::vector index(num_splits, 1); std::vector ragged_component_values_size(num_components, 0); @@ -104,10 +91,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, int last_index = ragged_component_splits_vec[j - 1].size() - 1; split_size = ragged_component_splits_vec[j - 1](last_index) + 1; } - (*ragged_components)[i].nested_splits.push_back( + (*ragged_components)[i].append_splits( Tensor(DataTypeToEnum::value, TensorShape({split_size}))); ragged_component_splits_vec.push_back( - (*ragged_components)[i].nested_splits[j].vec()); + (*ragged_components)[i].mutable_splits(j)->vec()); SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1); ragged_component_splits_vec[j](0) = 0; for (int k = 1; k < split_size; k++, index[j]++) { @@ -125,10 +112,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, for (int i = 0; i < num_components; i++) { int num_values = ragged_component_values_size[i]; values_shape.set_dim(0, num_values); - (*ragged_components)[i].values = - Tensor(DataTypeToEnum::value, values_shape); + (*ragged_components)[i].set_values( + Tensor(DataTypeToEnum::value, values_shape)); auto ragged_component_values_flat = - (*ragged_components)[i].values.flat(); + (*ragged_components)[i].mutable_values()->flat(); for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) { ragged_component_values_flat(j) = batched_flat(value_index); } @@ -152,46 +139,38 @@ class RaggedTensorToVariantOp : public OpKernel { OP_REQUIRES_OK(context, context->input_list("rt_nested_splits", &ragged_nested_splits_in)); const int ragged_nested_splits_len = ragged_nested_splits_in.size(); - RaggedTensor batched_ragged_input; + RaggedTensorVariant batched_ragged_input; // Read ragged_values input. - batched_ragged_input.values = context->input(ragged_nested_splits_len); - batched_ragged_input.nested_splits.reserve(ragged_nested_splits_len); + batched_ragged_input.set_values(context->input(ragged_nested_splits_len)); + batched_ragged_input.mutable_nested_splits()->reserve( + ragged_nested_splits_len); for (int i = 0; i < ragged_nested_splits_len; i++) { - batched_ragged_input.nested_splits.push_back(ragged_nested_splits_in[i]); + batched_ragged_input.append_splits(ragged_nested_splits_in[i]); } if (!batched_input_) { - // Encode the input as is. - Tensor encoded_list; - OP_REQUIRES_OK(context, - RaggedToVariant(batched_ragged_input, &encoded_list)); // Encode as a Scalar Variant Tensor. Tensor* encoded_scalar; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &encoded_scalar)); - encoded_scalar->scalar()() = std::move(encoded_list); + encoded_scalar->scalar()() = std::move(batched_ragged_input); return; } // Unbatch the Ragged Tensor and encode the components. - std::vector ragged_components; + std::vector unbatched_ragged_input; OP_REQUIRES_OK(context, UnbatchRaggedZerothDim( - batched_ragged_input, &ragged_components)); - std::vector encoded_components(ragged_components.size()); - for (int i = 0; i < ragged_components.size(); i++) { - OP_REQUIRES_OK(context, RaggedToVariant(ragged_components[i], - &encoded_components[i])); - } + batched_ragged_input, &unbatched_ragged_input)); // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor. - Tensor* encoded_ragged; - int output_size = ragged_components.size(); + Tensor* encoded_vector; + int output_size = unbatched_ragged_input.size(); OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({output_size}), - &encoded_ragged)); - auto encoded_ragged_vec = encoded_ragged->vec(); + &encoded_vector)); + auto encoded_vector_t = encoded_vector->vec(); for (int i = 0; i < output_size; i++) { - encoded_ragged_vec(i) = encoded_components[i]; + encoded_vector_t(i) = unbatched_ragged_input[i]; } } @@ -199,12 +178,81 @@ class RaggedTensorToVariantOp : public OpKernel { bool batched_input_; }; -#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \ - REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("Tvalues") \ - .TypeConstraint("Tsplits"), \ - RaggedTensorToVariantOp); +template +class RaggedTensorToVariantGradientOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* context) override { + // Read inputs. + Tensor encoded_variant = context->input(0); + Tensor row_splits = context->input(1); + auto flat_row_splits = row_splits.flat(); + TensorShape dense_values_shape; + OP_REQUIRES_OK(context, + TensorShapeUtils::MakeShape(context->input(2).vec(), + &dense_values_shape)); + + const auto& flat_variants = encoded_variant.flat(); + + // Get a Tensor containing the flat_values for each variant. + std::vector values; + for (int i = 0; i < flat_variants.size(); ++i) { + if (const auto* encoded = flat_variants(i).get()) { + values.push_back(encoded->values()); + } else { + // Missing value: this happens if only some of the variant values + // generated by ragged_tensor_to_variant impacted the value that we're + // calculating the gradient for. In this case, we will see a + // default-constructed variant; so treat it as a zero tensor with the + // appropriate shape. + const auto value_dtype = DataTypeToEnum::v(); + int piece_size = flat_row_splits(i + 1) - flat_row_splits(i); + TensorShape zeros_shape = dense_values_shape; + zeros_shape.set_dim(0, piece_size); + Tensor zero(value_dtype, zeros_shape); + zero.flat() = + zero.flat().constant(VALUE_TYPE()); + values.push_back(zero); + } + } + + if (values.size() == 1) { + // Just one flat_value tensor: return as-is. + context->set_output(0, values[0]); + } else { + // Multiple flat_values tensors: concatenate them together. + using Piece = typename TTypes::Matrix; + using ConstPiece = typename TTypes::ConstMatrix; + std::vector> pieces; + pieces.reserve(values.size()); + for (const Tensor& t : values) { + pieces.emplace_back( + new ConstPiece(t.shaped({1, t.NumElements()}))); + } + Tensor* out = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, dense_values_shape, &out)); + Piece out_flat = + out->shaped({1, dense_values_shape.num_elements()}); + ConcatCPU(context->device(), pieces, &out_flat); + } + } +}; + +#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \ + REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Tvalues") \ + .TypeConstraint("Tsplits"), \ + RaggedTensorToVariantOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("RaggedTensorToVariantGradient") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Tvalues") \ + .TypeConstraint("Tsplits"), \ + RaggedTensorToVariantGradientOp); + #define REGISTER_KERNELS(value_type) \ REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \ REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64) diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc index c1438dd7af9..94f35673c8b 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -60,6 +61,43 @@ class RaggedTensorToVariantKernelTest : public ::tensorflow::OpsTestBase { } AddInputFromArray(ragged_values_shape, ragged_values); } + + template + RaggedTensorVariant CreateVariantFromRagged( + const std::vector>& ragged_splits, + const TensorShape& ragged_values_shape, + const std::vector& ragged_values) { + RaggedTensorVariant encoded; + for (auto ragged_split : ragged_splits) { + int splits_size = ragged_split.size(); + Tensor splits(DataTypeToEnum::v(), + TensorShape({splits_size})); + test::FillValues(&splits, ragged_split); + encoded.append_splits(splits); + } + Tensor values(DataTypeToEnum::v(), ragged_values_shape); + test::FillValues(&values, ragged_values); + encoded.set_values(values); + return encoded; + } + + template + RaggedTensorVariant CreateVariantFromRagged( + const std::vector>& ragged_splits, + const std::vector& ragged_values) { + int num_values = ragged_values.size(); + return CreateVariantFromRagged(ragged_splits, {num_values}, ragged_values); + } + + template + void ExpectRaggedTensorVariantEqual(const RaggedTensorVariant& expected, + const RaggedTensorVariant& actual) { + test::ExpectTensorEqual(actual.values(), expected.values()); + EXPECT_EQ(actual.ragged_rank(), expected.ragged_rank()); + for (int i = 0; i < actual.ragged_rank(); ++i) { + test::ExpectTensorEqual(actual.splits(i), expected.splits(i)); + } + } }; TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) { @@ -67,18 +105,6 @@ TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) { const std::vector batched_splits_1 = {0, 2, 3, 3}; const std::vector batched_splits_2 = {0, 0, 0, 0}; - const std::vector component_splits_1_1 = {0, 0, 0}; - const std::vector component_splits_2_1 = {0, 0}; - const std::vector component_splits_3_1 = {0}; - - Tensor expected_splits_1_1(DT_INT64, TensorShape({3})); - Tensor expected_splits_2_1(DT_INT64, TensorShape({2})); - Tensor expected_splits_3_1(DT_INT64, TensorShape({1})); - - test::FillValues(&expected_splits_1_1, component_splits_1_1); - test::FillValues(&expected_splits_2_1, component_splits_2_1); - test::FillValues(&expected_splits_3_1, component_splits_3_1); - BuildEncodeRaggedTensorGraph({batched_splits_1, batched_splits_2}, TensorShape({0}), {}, true); TF_ASSERT_OK(RunOpKernel()); @@ -86,55 +112,26 @@ TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) { const auto& encoded_list = GetOutput(0)->vec(); EXPECT_EQ(encoded_list.size(), 3); - const Variant& encoded_splits_1_1 = - encoded_list(0).get()->vec()(0); - const Variant& encoded_values_1 = - encoded_list(0).get()->vec()(1); - const Variant& encoded_splits_2_1 = - encoded_list(1).get()->vec()(0); - const Variant& encoded_values_2 = - encoded_list(1).get()->vec()(1); - const Variant& encoded_splits_3_1 = - encoded_list(2).get()->vec()(0); - const Variant& encoded_values_3 = - encoded_list(2).get()->vec()(1); - - test::ExpectTensorEqual(*encoded_splits_1_1.get(), - expected_splits_1_1); - test::ExpectTensorEqual(*encoded_splits_2_1.get(), - expected_splits_2_1); - test::ExpectTensorEqual(*encoded_splits_3_1.get(), - expected_splits_3_1); - test::ExpectTensorEqual(*encoded_values_1.get(), - Tensor(DT_INT32, TensorShape({0}))); - test::ExpectTensorEqual(*encoded_values_2.get(), - Tensor(DT_INT32, TensorShape({0}))); - test::ExpectTensorEqual(*encoded_values_3.get(), - Tensor(DT_INT32, TensorShape({0}))); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({{0, 0, 0}}, {}), + *encoded_list(0).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({{0, 0}}, {}), + *encoded_list(1).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({{0}}, {}), + *encoded_list(2).get()); } TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) { // ragged_tensor= - // [ [x, x, x], + // [ [1, 2, 3], // [ ], - // [x, x ], - // [x ]] + // [4, 5 ], + // [6 ]] const std::vector batched_splits = {0, 3, 3, 5, 6}; const std::vector batched_values = {1, 2, 3, 4, 5, 6}; - const std::vector component_values_1 = {1, 2, 3}; - const std::vector component_values_3 = {4, 5}; - const std::vector component_values_4 = {6}; - - Tensor expected_values_1(DT_INT32, TensorShape({3})); - Tensor expected_values_2(DT_INT32, TensorShape({0})); - Tensor expected_values_3(DT_INT32, TensorShape({2})); - Tensor expected_values_4(DT_INT32, TensorShape({1})); - - test::FillValues(&expected_values_1, component_values_1); - test::FillValues(&expected_values_3, component_values_3); - test::FillValues(&expected_values_4, component_values_4); - BuildEncodeRaggedTensorGraph({batched_splits}, TensorShape({6}), batched_values, true); TF_ASSERT_OK(RunOpKernel()); @@ -142,45 +139,28 @@ TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) { const auto& encoded_list = GetOutput(0)->vec(); EXPECT_EQ(encoded_list.size(), 4); - const Variant& encoded_values_1 = - encoded_list(0).get()->vec()(0); - const Variant& encoded_values_2 = - encoded_list(1).get()->vec()(0); - const Variant& encoded_values_3 = - encoded_list(2).get()->vec()(0); - const Variant& encoded_values_4 = - encoded_list(3).get()->vec()(0); - - test::ExpectTensorEqual(*encoded_values_1.get(), - expected_values_1); - test::ExpectTensorEqual(*encoded_values_2.get(), - expected_values_2); - test::ExpectTensorEqual(*encoded_values_3.get(), - expected_values_3); - test::ExpectTensorEqual(*encoded_values_4.get(), - expected_values_4); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({}, {1, 2, 3}), + *encoded_list(0).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({}, {}), + *encoded_list(1).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({}, {4, 5}), + *encoded_list(2).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({}, {6}), + *encoded_list(3).get()); } TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) { // ragged_tensor= - // [[x, x], - // [x, x], - // [x, x]] + // [[1, 2], + // [4, 5], + // [6, 7]] const std::vector batched_splits = {0, 1, 2, 3}; const std::vector batched_values = {1, 2, 4, 5, 6, 7}; - const std::vector component_values_1 = {1, 2}; - const std::vector component_values_2 = {4, 5}; - const std::vector component_values_3 = {6, 7}; - - Tensor expected_values_1(DT_INT32, TensorShape({1, 2})); - Tensor expected_values_2(DT_INT32, TensorShape({1, 2})); - Tensor expected_values_3(DT_INT32, TensorShape({1, 2})); - - test::FillValues(&expected_values_1, component_values_1); - test::FillValues(&expected_values_2, component_values_2); - test::FillValues(&expected_values_3, component_values_3); - BuildEncodeRaggedTensorGraph( {batched_splits}, TensorShape({3, 2}), batched_values, true); TF_ASSERT_OK(RunOpKernel()); @@ -188,44 +168,25 @@ TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) { const auto& encoded_list = GetOutput(0)->vec(); EXPECT_EQ(encoded_list.size(), 3); - const Variant& encoded_values_1 = - encoded_list(0).get()->vec()(0); - const Variant& encoded_values_2 = - encoded_list(1).get()->vec()(0); - const Variant& encoded_values_3 = - encoded_list(2).get()->vec()(0); - - test::ExpectTensorEqual(*encoded_values_1.get(), - expected_values_1); - test::ExpectTensorEqual(*encoded_values_2.get(), - expected_values_2); - test::ExpectTensorEqual(*encoded_values_3.get(), - expected_values_3); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({}, {1, 2}, {1, 2}), + *encoded_list(0).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({}, {1, 2}, {4, 5}), + *encoded_list(1).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({}, {1, 2}, {6, 7}), + *encoded_list(2).get()); } TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) { - // ragged_tensor=[ - // [ [[x, x], [x, x]], - // [[x, x] ] ] + // ragged_tensor= + // [ [[[1, 2], [4, 5]]], + // [[[6 7]]] ] const std::vector batched_splits_1 = {0, 1, 2}; const std::vector batched_splits_2 = {0, 2, 3}; const std::vector batched_values = {1, 2, 4, 5, 6, 7}; - const std::vector component_splits_1_1 = {0, 2}; - const std::vector component_splits_2_1 = {0, 1}; - const std::vector component_values_1 = {1, 2, 4, 5}; - const std::vector component_values_2 = {6, 7}; - - Tensor expected_splits_1_1(DT_INT64, TensorShape({2})); - Tensor expected_splits_2_1(DT_INT64, TensorShape({2})); - Tensor expected_values_1(DT_INT32, TensorShape({2, 2})); - Tensor expected_values_2(DT_INT32, TensorShape({1, 2})); - - test::FillValues(&expected_splits_1_1, component_splits_1_1); - test::FillValues(&expected_splits_2_1, component_splits_2_1); - test::FillValues(&expected_values_1, component_values_1); - test::FillValues(&expected_values_2, component_values_2); - BuildEncodeRaggedTensorGraph({batched_splits_1, batched_splits_2}, TensorShape({3, 2}), batched_values, true); @@ -234,23 +195,12 @@ TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) { const auto& encoded_list = GetOutput(0)->vec(); EXPECT_EQ(encoded_list.size(), 2); - const Variant& encoded_splits_1_1 = - encoded_list(0).get()->vec()(0); - const Variant& encoded_values_1 = - encoded_list(0).get()->vec()(1); - const Variant& encoded_splits_2_1 = - encoded_list(1).get()->vec()(0); - const Variant& encoded_values_2 = - encoded_list(1).get()->vec()(1); - - test::ExpectTensorEqual(*encoded_splits_1_1.get(), - expected_splits_1_1); - test::ExpectTensorEqual(*encoded_values_1.get(), - expected_values_1); - test::ExpectTensorEqual(*encoded_splits_2_1.get(), - expected_splits_2_1); - test::ExpectTensorEqual(*encoded_values_2.get(), - expected_values_2); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({{0, 2}}, {2, 2}, {1, 2, 4, 5}), + *encoded_list(0).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({{0, 1}}, {1, 2}, {6, 7}), + *encoded_list(1).get()); } TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) { @@ -263,30 +213,6 @@ TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) { const std::vector batched_splits_2 = {0, 1, 3, 3, 8, 11, 11, 15}; const std::vector batched_values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - const std::vector component_splits_1_1 = {0, 1, 3, 3}; - const std::vector component_splits_2_1 = {0}; - const std::vector component_splits_3_1 = {0, 5, 8}; - const std::vector component_splits_4_1 = {0, 0, 4}; - const std::vector component_values_1 = {1, 2, 3}; - const std::vector component_values_3 = {4, 5, 6, 7, 8, 9, 10, 11}; - const std::vector component_values_4 = {12, 13, 14, 15}; - - Tensor expected_splits_1_1(DT_INT64, TensorShape({4})); - Tensor expected_splits_2_1(DT_INT64, TensorShape({1})); - Tensor expected_splits_3_1(DT_INT64, TensorShape({3})); - Tensor expected_splits_4_1(DT_INT64, TensorShape({3})); - Tensor expected_values_1(DT_INT32, TensorShape({3})); - Tensor expected_values_2(DT_INT32, TensorShape({0})); - Tensor expected_values_3(DT_INT32, TensorShape({8})); - Tensor expected_values_4(DT_INT32, TensorShape({4})); - - test::FillValues(&expected_splits_1_1, component_splits_1_1); - test::FillValues(&expected_splits_2_1, component_splits_2_1); - test::FillValues(&expected_splits_3_1, component_splits_3_1); - test::FillValues(&expected_splits_4_1, component_splits_4_1); - test::FillValues(&expected_values_1, component_values_1); - test::FillValues(&expected_values_3, component_values_3); - test::FillValues(&expected_values_4, component_values_4); BuildEncodeRaggedTensorGraph({batched_splits_1, batched_splits_2}, TensorShape({15}), batched_values, @@ -296,39 +222,19 @@ TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) { const auto& encoded_list = GetOutput(0)->vec(); EXPECT_EQ(encoded_list.size(), 4); - const Variant& encoded_splits_1_1 = - encoded_list(0).get()->vec()(0); - const Variant& encoded_values_1 = - encoded_list(0).get()->vec()(1); - const Variant& encoded_splits_2_1 = - encoded_list(1).get()->vec()(0); - const Variant& encoded_values_2 = - encoded_list(1).get()->vec()(1); - const Variant& encoded_splits_3_1 = - encoded_list(2).get()->vec()(0); - const Variant& encoded_values_3 = - encoded_list(2).get()->vec()(1); - const Variant& encoded_splits_4_1 = - encoded_list(3).get()->vec()(0); - const Variant& encoded_values_4 = - encoded_list(3).get()->vec()(1); - - test::ExpectTensorEqual(*encoded_splits_1_1.get(), - expected_splits_1_1); - test::ExpectTensorEqual(*encoded_values_1.get(), - expected_values_1); - test::ExpectTensorEqual(*encoded_splits_2_1.get(), - expected_splits_2_1); - test::ExpectTensorEqual(*encoded_values_2.get(), - expected_values_2); - test::ExpectTensorEqual(*encoded_splits_3_1.get(), - expected_splits_3_1); - test::ExpectTensorEqual(*encoded_values_3.get(), - expected_values_3); - test::ExpectTensorEqual(*encoded_splits_4_1.get(), - expected_splits_4_1); - test::ExpectTensorEqual(*encoded_values_4.get(), - expected_values_4); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({{0, 1, 3, 3}}, {1, 2, 3}), + *encoded_list(0).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({{0}}, {}), + *encoded_list(1).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({{0, 5, 8}}, + {4, 5, 6, 7, 8, 9, 10, 11}), + *encoded_list(2).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({{0, 0, 4}}, {12, 13, 14, 15}), + *encoded_list(3).get()); } TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) { @@ -350,26 +256,6 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) { 7, 8, 9, 12, 13, 14}; const std::vector batched_values = {0, 1, 1, 2, 2, 3, 4, 5, 6, 7, 8, 9, 8, 9}; - const std::vector component_split_1_1 = {0, 1, 3, 4, 5, 6}; - const std::vector component_split_1_2 = {0, 2, 3, 4, 5, 6, 7}; - const std::vector component_split_2_1 = {0, 1, 2, 3, 4, 5}; - const std::vector component_split_2_2 = {0, 1, 2, 5, 6, 7}; - const std::vector component_values_1 = {0, 1, 1, 2, 2, 3, 4}; - const std::vector component_values_2 = {5, 6, 7, 8, 9, 8, 9}; - - Tensor expected_splits_1_1(DT_INT64, TensorShape({6})); - Tensor expected_splits_1_2(DT_INT64, TensorShape({7})); - Tensor expected_splits_2_1(DT_INT64, TensorShape({6})); - Tensor expected_splits_2_2(DT_INT64, TensorShape({6})); - Tensor expected_values_1(DT_INT32, TensorShape({7})); - Tensor expected_values_2(DT_INT32, TensorShape({7})); - - test::FillValues(&expected_splits_1_1, component_split_1_1); - test::FillValues(&expected_splits_1_2, component_split_1_2); - test::FillValues(&expected_splits_2_1, component_split_2_1); - test::FillValues(&expected_splits_2_2, component_split_2_2); - test::FillValues(&expected_values_1, component_values_1); - test::FillValues(&expected_values_2, component_values_2); BuildEncodeRaggedTensorGraph( {batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}), @@ -379,31 +265,14 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) { const auto& encoded_list = GetOutput(0)->vec(); EXPECT_EQ(encoded_list.size(), 2); - const Variant& encoded_splits_1_1 = - encoded_list(0).get()->vec()(0); - const Variant& encoded_splits_1_2 = - encoded_list(0).get()->vec()(1); - const Variant& encoded_values_1 = - encoded_list(0).get()->vec()(2); - const Variant& encoded_splits_2_1 = - encoded_list(1).get()->vec()(0); - const Variant& encoded_splits_2_2 = - encoded_list(1).get()->vec()(1); - const Variant& encoded_values_2 = - encoded_list(1).get()->vec()(2); - - test::ExpectTensorEqual(*encoded_splits_1_1.get(), - expected_splits_1_1); - test::ExpectTensorEqual(*encoded_splits_1_2.get(), - expected_splits_1_2); - test::ExpectTensorEqual(*encoded_splits_2_1.get(), - expected_splits_2_1); - test::ExpectTensorEqual(*encoded_splits_2_2.get(), - expected_splits_2_2); - test::ExpectTensorEqual(*encoded_values_1.get(), - expected_values_1); - test::ExpectTensorEqual(*encoded_values_2.get(), - expected_values_2); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged( + {{0, 1, 3, 4, 5, 6}, {0, 2, 3, 4, 5, 6, 7}}, {0, 1, 1, 2, 2, 3, 4}), + *encoded_list(0).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged( + {{0, 1, 2, 3, 4, 5}, {0, 1, 2, 5, 6, 7}}, {5, 6, 7, 8, 9, 8, 9}), + *encoded_list(1).get()); } TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) { @@ -424,28 +293,8 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) { 7, 8, 9, 12, 13, 14}; const std::vector batched_values = {0, 1, 1, 2, 2, 3, 4, 5, 6, 7, 8, 9, 8, 9}; - const std::vector component_split_1_1 = {0, 1, 3, 4, 5, 6}; - const std::vector component_split_1_2 = {0, 2, 3, 4, 5, 6, 7}; - const std::vector component_split_2_1 = {0, 1, 2, 3, 4, 5}; - const std::vector component_split_2_2 = {0, 1, 2, 5, 6, 7}; - const std::vector component_values_1 = {0, 1, 1, 2, 2, 3, 4}; - const std::vector component_values_2 = {5, 6, 7, 8, 9, 8, 9}; - Tensor expected_splits_1_1(DT_INT32, TensorShape({6})); - Tensor expected_splits_1_2(DT_INT32, TensorShape({7})); - Tensor expected_splits_2_1(DT_INT32, TensorShape({6})); - Tensor expected_splits_2_2(DT_INT32, TensorShape({6})); - Tensor expected_values_1(DT_INT32, TensorShape({7})); - Tensor expected_values_2(DT_INT32, TensorShape({7})); - - test::FillValues(&expected_splits_1_1, component_split_1_1); - test::FillValues(&expected_splits_1_2, component_split_1_2); - test::FillValues(&expected_splits_2_1, component_split_2_1); - test::FillValues(&expected_splits_2_2, component_split_2_2); - test::FillValues(&expected_values_1, component_values_1); - test::FillValues(&expected_values_2, component_values_2); - - BuildEncodeRaggedTensorGraph( + BuildEncodeRaggedTensorGraph( {batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}), batched_values, true); TF_ASSERT_OK(RunOpKernel()); @@ -453,31 +302,14 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) { const auto& encoded_list = GetOutput(0)->vec(); EXPECT_EQ(encoded_list.size(), 2); - const Variant& encoded_splits_1_1 = - encoded_list(0).get()->vec()(0); - const Variant& encoded_splits_1_2 = - encoded_list(0).get()->vec()(1); - const Variant& encoded_values_1 = - encoded_list(0).get()->vec()(2); - const Variant& encoded_splits_2_1 = - encoded_list(1).get()->vec()(0); - const Variant& encoded_splits_2_2 = - encoded_list(1).get()->vec()(1); - const Variant& encoded_values_2 = - encoded_list(1).get()->vec()(2); - - test::ExpectTensorEqual(*encoded_splits_1_1.get(), - expected_splits_1_1); - test::ExpectTensorEqual(*encoded_splits_1_2.get(), - expected_splits_1_2); - test::ExpectTensorEqual(*encoded_splits_2_1.get(), - expected_splits_2_1); - test::ExpectTensorEqual(*encoded_splits_2_2.get(), - expected_splits_2_2); - test::ExpectTensorEqual(*encoded_values_1.get(), - expected_values_1); - test::ExpectTensorEqual(*encoded_values_2.get(), - expected_values_2); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged( + {{0, 1, 3, 4, 5, 6}, {0, 2, 3, 4, 5, 6, 7}}, {0, 1, 1, 2, 2, 3, 4}), + *encoded_list(0).get()); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged( + {{0, 1, 2, 3, 4, 5}, {0, 1, 2, 5, 6, 7}}, {5, 6, 7, 8, 9, 8, 9}), + *encoded_list(1).get()); } TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) { @@ -491,33 +323,17 @@ TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) { const std::vector batched_values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - Tensor batched_ragged_splits_1(DT_INT64, TensorShape({5})); - Tensor batched_ragged_splits_2(DT_INT64, TensorShape({8})); - Tensor batched_ragged_values(DT_INT32, TensorShape({15})); - - test::FillValues(&batched_ragged_splits_1, batched_splits_1); - test::FillValues(&batched_ragged_splits_2, batched_splits_2); - test::FillValues(&batched_ragged_values, batched_values); - BuildEncodeRaggedTensorGraph({batched_splits_1, batched_splits_2}, TensorShape({15}), batched_values, false); TF_ASSERT_OK(RunOpKernel()); const auto& encoded_scalar = GetOutput(0)->scalar()(); - const Variant& encoded_splits_1 = - encoded_scalar.get()->vec()(0); - const Variant& encoded_splits_2 = - encoded_scalar.get()->vec()(1); - const Variant& encoded_values = - encoded_scalar.get()->vec()(2); - test::ExpectTensorEqual(*encoded_splits_1.get(), - batched_ragged_splits_1); - test::ExpectTensorEqual(*encoded_splits_2.get(), - batched_ragged_splits_2); - test::ExpectTensorEqual(*encoded_values.get(), - batched_ragged_values); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({batched_splits_1, batched_splits_2}, + batched_values), + *encoded_scalar.get()); } TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestBatched) { @@ -598,17 +414,14 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) { TEST_F(RaggedTensorToVariantKernelTest, NonRaggedInput) { const std::vector values = {1, 2, 3, 4, 5, 6}; - Tensor expected_values(DT_INT32, TensorShape({6})); - test::FillValues(&expected_values, values); BuildEncodeRaggedTensorGraph({}, TensorShape({6}), values, false); TF_ASSERT_OK(RunOpKernel()); const auto& encoded_scalar = GetOutput(0)->scalar()(); - const Variant& encoded_values = - encoded_scalar.get()->vec()(0); - - test::ExpectTensorEqual(*encoded_values.get(), expected_values); + ExpectRaggedTensorVariantEqual( + CreateVariantFromRagged({}, values), + *encoded_scalar.get()); } } // namespace diff --git a/tensorflow/core/kernels/ragged_tensor_variant.cc b/tensorflow/core/kernels/ragged_tensor_variant.cc new file mode 100644 index 00000000000..9466313819b --- /dev/null +++ b/tensorflow/core/kernels/ragged_tensor_variant.cc @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "tensorflow/core/kernels/ragged_tensor_variant.h" + +namespace tensorflow { + +string RaggedTensorVariant::TypeName() const { return "RaggedTensorVariant"; } + +string RaggedTensorVariant::DebugString() const { + return absl::StrCat( + "RaggedTensorVariant(dtype=", DataTypeString(values_.dtype()), + ", ragged_rank=", nested_splits_.size(), ", splits_dtype=", + DataTypeString(nested_splits_.empty() ? DT_INVALID + : nested_splits_.back().dtype())); +} + +void RaggedTensorVariant::Encode(VariantTensorData* data) const { + data->set_type_name(TypeName()); + for (const auto& splits : nested_splits_) { + *data->add_tensors() = splits; + } + *data->add_tensors() = values_; +} + +bool RaggedTensorVariant::Decode(const VariantTensorData& data) { + if (data.tensors_size() < 1) { + return false; + } + nested_splits_.assign(data.tensors().begin(), + std::prev(data.tensors().end())); + values_ = data.tensors().back(); + return true; +} + +namespace { + +Status RaggedTensorVariantDeviceCopy( + const RaggedTensorVariant& from, RaggedTensorVariant* to, + const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { + TF_RETURN_IF_ERROR(copy(from.values(), to->mutable_values())); + // TODO(b/170415165) Should we use `copy` to move splits from device<->host? + *to->mutable_nested_splits() = from.nested_splits(); + return Status::OK(); +} + +} // namespace + +REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION( + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, RaggedTensorVariant, + RaggedTensorVariantZerosLike); + +REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION( + ADD_VARIANT_BINARY_OP, DEVICE_CPU, RaggedTensorVariant, + RaggedTensorVariantBinaryAdd); + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(RaggedTensorVariant, + "RaggedTensorVariant"); + +#define REGISTER_RAGGED_TENSOR_VARIANT_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ + RaggedTensorVariant, DIRECTION, RaggedTensorVariantDeviceCopy) + +REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); +REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); +REGISTER_RAGGED_TENSOR_VARIANT_COPY( + VariantDeviceCopyDirection::DEVICE_TO_DEVICE); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/ragged_tensor_variant.h b/tensorflow/core/kernels/ragged_tensor_variant.h new file mode 100644 index 00000000000..730758a3e82 --- /dev/null +++ b/tensorflow/core/kernels/ragged_tensor_variant.h @@ -0,0 +1,110 @@ +#include "tensorflow/core/framework/tensor_key.h" +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ +#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ + +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/kernels/cwise_ops_common.h" +#include "tensorflow/core/util/tensor_ops_util.h" + +namespace tensorflow { + +// Class used to store a RaggedTensor as a Variant scalar. +class RaggedTensorVariant { + public: + RaggedTensorVariant() {} + RaggedTensorVariant(Tensor values, const std::vector& nested_splits) + : values_(std::move(values)), nested_splits_(nested_splits) {} + + // Variant support methods. + string TypeName() const; + string DebugString() const; + void Encode(VariantTensorData* data) const; + bool Decode(const VariantTensorData& data); + + // The flat_values of the RaggedTensor. + const Tensor& values() const { return values_; } + Tensor* mutable_values() { return &values_; } + void set_values(const Tensor& new_values) { values_ = new_values; } + + // The nested row_splits of the RaggedTensor. + int ragged_rank() const { return nested_splits_.size(); } + const std::vector& nested_splits() const { return nested_splits_; } + std::vector* mutable_nested_splits() { return &nested_splits_; } + const Tensor& splits(int i) const { return nested_splits_[i]; } + Tensor* mutable_splits(int i) { return &nested_splits_[i]; } + void set_nested_splits(const std::vector& nested_splits) { + nested_splits_ = nested_splits; + } + void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); } + + private: + Tensor values_; + std::vector nested_splits_; +}; + +template +Status RaggedTensorVariantZerosLike(OpKernelContext* c, + const RaggedTensorVariant& x, + RaggedTensorVariant* y) { + y->set_nested_splits(x.nested_splits()); + TF_RETURN_IF_ERROR( + ZerosLikeTensor(c, x.values(), y->mutable_values())); + return Status::OK(); +} + +template +Status RaggedTensorVariantBinaryAdd(OpKernelContext* c, + const RaggedTensorVariant& x, + const RaggedTensorVariant& y, + RaggedTensorVariant* out) { + if (x.values().dtype() != y.values().dtype()) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants of different dtypes. One is ", + DataTypeString(x.values().dtype()), " and the other is ", + DataTypeString(y.values().dtype())); + } + if (x.ragged_rank() != y.ragged_rank()) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants of different ragged rank. ", "One is ", + x.ragged_rank(), " and the other is ", y.ragged_rank()); + } + for (int i = 0; i < x.ragged_rank(); ++i) { + if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants with different row_splits."); + } + } + out->set_nested_splits(x.nested_splits()); + TF_RETURN_IF_ERROR(BinaryAddTensors(c, x.values(), y.values(), + out->mutable_values())); + return Status::OK(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index 59184ab061d..fc439a08df1 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -73,6 +73,20 @@ struct DividesBy { __host__ __device__ OUT_T operator()(const T& x) const { return x / divisor; } }; +struct MaxPropagateNaN { + template + __host__ __device__ inline T operator()(const T& a, const T& b) const { + return (a != a ? a : (a > b ? a : b)); + } +}; + +struct MinPropagateNaN { + template + __host__ __device__ inline T operator()(const T& a, const T& b) const { + return (a != a ? a : (a < b ? a : b)); + } +}; + #if GOOGLE_CUDA // TODO(rocm) : enable this once ROCm platform has support for complex datatypes // @@ -986,15 +1000,19 @@ struct IsSum { template struct IsMax { constexpr static bool value = - (std::is_same::value || - std::is_same>::value); + (std::is_same::value || + std::is_same::value || + std::is_same< + Op, Eigen::internal::MaxReducer>::value); }; template struct IsMin { constexpr static bool value = - (std::is_same::value || - std::is_same>::value); + (std::is_same::value || + std::is_same::value || + std::is_same< + Op, Eigen::internal::MinReducer>::value); }; template @@ -1222,41 +1240,47 @@ struct ReduceFunctor> { }; template -struct ReduceFunctor> { +struct ReduceFunctor> { template - static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, - const ReductionAxes& reduction_axes, - const Eigen::internal::MaxReducer& reducer) { - ReduceImpl( + static void Reduce( + OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::MaxReducer& reducer) { + ReduceImpl( ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0), in.rank() >= 2 ? in.dimension(1) : 1, in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, - gpuprim::Max()); + MaxPropagateNaN()); } template - static void FillIdentity(const GPUDevice& d, OUT_T out, - const Eigen::internal::MaxReducer& reducer) { + static void FillIdentity( + const GPUDevice& d, OUT_T out, + const Eigen::internal::MaxReducer& reducer) { FillIdentityEigenImpl(d, To32Bit(out), reducer); } }; template -struct ReduceFunctor> { +struct ReduceFunctor> { template - static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, - const ReductionAxes& reduction_axes, - const Eigen::internal::MinReducer& reducer) { - ReduceImpl( + static void Reduce( + OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::MinReducer& reducer) { + ReduceImpl( ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0), in.rank() >= 2 ? in.dimension(1) : 1, in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, - gpuprim::Min()); + MinPropagateNaN()); } template - static void FillIdentity(const GPUDevice& d, OUT_T out, - const Eigen::internal::MinReducer& reducer) { + static void FillIdentity( + const GPUDevice& d, OUT_T out, + const Eigen::internal::MinReducer& reducer) { FillIdentityEigenImpl(d, To32Bit(out), reducer); } }; diff --git a/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc index c952c4c9fa4..dfd31795b35 100644 --- a/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc @@ -44,22 +44,27 @@ typedef TTypes::Tensor::Index Index; template void ReduceFunctor::FillIdentity( \ const GPUDevice& d, TTypes::Vec out, const REDUCER& reducer); -#define DEFINE_FOR_TYPE_AND_R(T, R) \ - DEFINE(T, R, 1, 1); \ - DEFINE(T, R, 2, 1); \ - DEFINE(T, R, 3, 1); \ - DEFINE(T, R, 3, 2); \ - DEFINE_IDENTITY(T, R) +#define SINGLE_ARG(...) __VA_ARGS__ -#define DEFINE_FOR_ALL_REDUCERS(T) \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ - DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer); \ - DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ +#define DEFINE_FOR_TYPE_AND_R(T, R) \ + DEFINE(T, SINGLE_ARG(R), 1, 1); \ + DEFINE(T, SINGLE_ARG(R), 2, 1); \ + DEFINE(T, SINGLE_ARG(R), 3, 1); \ + DEFINE(T, SINGLE_ARG(R), 3, 2); \ + DEFINE_IDENTITY(T, SINGLE_ARG(R)) + +#define DEFINE_FOR_ALL_REDUCERS(T) \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ + DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer); \ + DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer); \ + DEFINE_FOR_TYPE_AND_R( \ + T, SINGLE_ARG(Eigen::internal::MinReducer)); \ + DEFINE_FOR_TYPE_AND_R( \ + T, SINGLE_ARG(Eigen::internal::MaxReducer)); \ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer) DEFINE_FOR_ALL_REDUCERS(double); +#undef SINGLE_ARG #undef DEFINE_FOR_ALL_REDUCERS #undef DEFINE_FOR_TYPE_AND_R #undef DEFINE diff --git a/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc index 92f4b9d707c..bf9831a1207 100644 --- a/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc @@ -44,22 +44,27 @@ typedef TTypes::Tensor::Index Index; template void ReduceFunctor::FillIdentity( \ const GPUDevice& d, TTypes::Vec out, const REDUCER& reducer); -#define DEFINE_FOR_TYPE_AND_R(T, R) \ - DEFINE(T, R, 1, 1); \ - DEFINE(T, R, 2, 1); \ - DEFINE(T, R, 3, 1); \ - DEFINE(T, R, 3, 2); \ - DEFINE_IDENTITY(T, R) +#define SINGLE_ARG(...) __VA_ARGS__ -#define DEFINE_FOR_ALL_REDUCERS(T) \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ - DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer); \ - DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ +#define DEFINE_FOR_TYPE_AND_R(T, R) \ + DEFINE(T, SINGLE_ARG(R), 1, 1); \ + DEFINE(T, SINGLE_ARG(R), 2, 1); \ + DEFINE(T, SINGLE_ARG(R), 3, 1); \ + DEFINE(T, SINGLE_ARG(R), 3, 2); \ + DEFINE_IDENTITY(T, SINGLE_ARG(R)) + +#define DEFINE_FOR_ALL_REDUCERS(T) \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ + DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer); \ + DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer); \ + DEFINE_FOR_TYPE_AND_R( \ + T, SINGLE_ARG(Eigen::internal::MinReducer)); \ + DEFINE_FOR_TYPE_AND_R( \ + T, SINGLE_ARG(Eigen::internal::MaxReducer)); \ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer) DEFINE_FOR_ALL_REDUCERS(float); +#undef SINGLE_ARG #undef DEFINE_FOR_ALL_REDUCERS #undef DEFINE_FOR_TYPE_AND_R #undef DEFINE diff --git a/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc index c35d8c2ec86..2efcad02950 100644 --- a/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc @@ -44,23 +44,28 @@ typedef TTypes::Tensor::Index Index; template void ReduceFunctor::FillIdentity( \ const GPUDevice& d, TTypes::Vec out, const REDUCER& reducer); -#define DEFINE_FOR_TYPE_AND_R(T, R) \ - DEFINE(T, R, 1, 1); \ - DEFINE(T, R, 2, 1); \ - DEFINE(T, R, 3, 1); \ - DEFINE(T, R, 3, 2); \ - DEFINE_IDENTITY(T, R) +#define SINGLE_ARG(...) __VA_ARGS__ -#define DEFINE_FOR_ALL_REDUCERS(T) \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ - DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer); \ - DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ +#define DEFINE_FOR_TYPE_AND_R(T, R) \ + DEFINE(T, SINGLE_ARG(R), 1, 1); \ + DEFINE(T, SINGLE_ARG(R), 2, 1); \ + DEFINE(T, SINGLE_ARG(R), 3, 1); \ + DEFINE(T, SINGLE_ARG(R), 3, 2); \ + DEFINE_IDENTITY(T, SINGLE_ARG(R)) + +#define DEFINE_FOR_ALL_REDUCERS(T) \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ + DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer); \ + DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer); \ + DEFINE_FOR_TYPE_AND_R( \ + T, SINGLE_ARG(Eigen::internal::MinReducer)); \ + DEFINE_FOR_TYPE_AND_R( \ + T, SINGLE_ARG(Eigen::internal::MaxReducer)); \ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer) DEFINE_FOR_ALL_REDUCERS(int32); DEFINE_FOR_ALL_REDUCERS(int64); +#undef SINGLE_ARG #undef DEFINE_FOR_ALL_REDUCERS #undef DEFINE_FOR_TYPE_AND_R #undef DEFINE diff --git a/tensorflow/core/kernels/reduction_ops_half_prod_max_min.cu.cc b/tensorflow/core/kernels/reduction_ops_half_prod_max_min.cu.cc index d2a180ba351..23c9ec9e592 100644 --- a/tensorflow/core/kernels/reduction_ops_half_prod_max_min.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_half_prod_max_min.cu.cc @@ -44,19 +44,24 @@ typedef TTypes::Tensor::Index Index; template void ReduceFunctor::FillIdentity( \ const GPUDevice& d, TTypes::Vec out, const REDUCER& reducer); -#define DEFINE_FOR_TYPE_AND_R(T, R) \ - DEFINE(T, R, 1, 1); \ - DEFINE(T, R, 2, 1); \ - DEFINE(T, R, 3, 1); \ - DEFINE(T, R, 3, 2); \ - DEFINE_IDENTITY(T, R) +#define SINGLE_ARG(...) __VA_ARGS__ -#define DEFINE_FOR_ALL_REDUCERS(T) \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ +#define DEFINE_FOR_TYPE_AND_R(T, R) \ + DEFINE(T, SINGLE_ARG(R), 1, 1); \ + DEFINE(T, SINGLE_ARG(R), 2, 1); \ + DEFINE(T, SINGLE_ARG(R), 3, 1); \ + DEFINE(T, SINGLE_ARG(R), 3, 2); \ + DEFINE_IDENTITY(T, SINGLE_ARG(R)) + +#define DEFINE_FOR_ALL_REDUCERS(T) \ + DEFINE_FOR_TYPE_AND_R( \ + T, SINGLE_ARG(Eigen::internal::MinReducer)); \ + DEFINE_FOR_TYPE_AND_R( \ + T, SINGLE_ARG(Eigen::internal::MaxReducer)); \ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer) DEFINE_FOR_ALL_REDUCERS(Eigen::half); +#undef SINGLE_ARG #undef DEFINE_FOR_ALL_REDUCERS #undef DEFINE_FOR_TYPE_AND_R #undef DEFINE diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc index 99b17f402af..c52874818f0 100644 --- a/tensorflow/core/kernels/reduction_ops_max.cc +++ b/tensorflow/core/kernels/reduction_ops_max.cc @@ -17,39 +17,43 @@ limitations under the License. namespace tensorflow { -#define REGISTER_CPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Max") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx"), \ - ReductionOp>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Max") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx"), \ - ReductionOp>); +#define REGISTER_CPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Max") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx"), \ + ReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Max") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx"), \ + ReductionOp>); TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#define REGISTER_GPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Max") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx") \ - .HostMemory("reduction_indices"), \ - ReductionOp>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Max") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx") \ - .HostMemory("reduction_indices"), \ - ReductionOp>); +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Max") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Max") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); REGISTER_GPU_KERNELS(Eigen::half); REGISTER_GPU_KERNELS(float); diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc index be1d09352e0..ee5310a469a 100644 --- a/tensorflow/core/kernels/reduction_ops_min.cc +++ b/tensorflow/core/kernels/reduction_ops_min.cc @@ -17,39 +17,43 @@ limitations under the License. namespace tensorflow { -#define REGISTER_CPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Min") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx"), \ - ReductionOp>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Min") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx"), \ - ReductionOp>); +#define REGISTER_CPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Min") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx"), \ + ReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Min") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx"), \ + ReductionOp>); TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#define REGISTER_GPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Min") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx") \ - .HostMemory("reduction_indices"), \ - ReductionOp>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Min") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx") \ - .HostMemory("reduction_indices"), \ - ReductionOp>); +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Min") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Min") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); REGISTER_GPU_KERNELS(Eigen::half); REGISTER_GPU_KERNELS(float); REGISTER_GPU_KERNELS(double); diff --git a/tensorflow/core/kernels/scan_ops.h b/tensorflow/core/kernels/scan_ops.h index 8afcac86c3f..0584e0ac77c 100644 --- a/tensorflow/core/kernels/scan_ops.h +++ b/tensorflow/core/kernels/scan_ops.h @@ -111,7 +111,7 @@ struct LogSumExpReducer { template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const { - auto max_reducer = Eigen::internal::MaxReducer(); + auto max_reducer = Eigen::internal::MaxReducer(); auto sum_reducer = Eigen::internal::SumReducer(); auto exp = Eigen::internal::scalar_exp_op(); auto cmp_lt = diff --git a/tensorflow/core/kernels/tile_functor_gpu.h b/tensorflow/core/kernels/tile_functor_gpu.h index a8db29926fc..191553f462e 100644 --- a/tensorflow/core/kernels/tile_functor_gpu.h +++ b/tensorflow/core/kernels/tile_functor_gpu.h @@ -66,7 +66,7 @@ void TileSimple(const Eigen::GpuDevice& d, Tensor* out, const Tensor& in) { } // Copies the input strides, output strides and input dimension sizes to the // device. - auto num_bytes = sizeof(int64) * host_buf.size(); + auto num_bytes = sizeof(int32) * host_buf.size(); auto dev_buf = d.allocate(num_bytes); // NOTE: host_buf is not allocated by GpuHostAllocator, and // therefore we are doing a sync copy effectively. diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc index 3390bd07308..77fb521941f 100644 --- a/tensorflow/core/kernels/topk_op.cc +++ b/tensorflow/core/kernels/topk_op.cc @@ -244,7 +244,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS_NAME #undef REGISTER_KERNELS -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -277,6 +277,6 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS); TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS -#endif // end GOOGLE_CUDA +#endif // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // end namespace tensorflow diff --git a/tensorflow/core/kernels/topk_op_gpu.h b/tensorflow/core/kernels/topk_op_gpu.h index d26dd7a8bc3..481025fbf01 100644 --- a/tensorflow/core/kernels/topk_op_gpu.h +++ b/tensorflow/core/kernels/topk_op_gpu.h @@ -15,11 +15,12 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_ #define TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include +#include #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -39,7 +40,7 @@ limitations under the License. namespace cub { template <> struct NumericTraits - : BaseTraits {}; + : BaseTraits {}; } // namespace cub #endif // GOOGLE_CUDA @@ -93,7 +94,6 @@ struct IndirectLinearData { Entry* const backing_data; }; -#if GOOGLE_CUDA template struct StridedData { typedef impl::Entry Entry; @@ -107,7 +107,6 @@ struct StridedData { Entry* const data; }; -#endif // A heap of Entry that can either work as a min-heap or as a max-heap. template ::Entry Entry; const Data data; + __device__ IndexedHeap(const Data& d) : data(d) {} __device__ bool is_above(int left, int right) { T left_value = data.get_value(left); @@ -337,12 +337,21 @@ __device__ void mergeShards(int num_shards, int k, } } +#if GOOGLE_CUDA extern __shared__ char shared_memory[]; +#endif // GOOGLE_CUDA template -__global__ void TopKKernel(const T* __restrict__ input, int length, int k, - bool sorted, T* __restrict__ output, - int* __restrict__ indices) { +#if TENSORFLOW_USE_ROCM +__attribute__((amdgpu_flat_work_group_size(1, 256))) +#endif // TENSORFLOW_USE_ROCM +__global__ void +TopKKernel(const T* __restrict__ input, int length, int k, bool sorted, + T* __restrict__ output, int* __restrict__ indices) { +#if TENSORFLOW_USE_ROCM + HIP_DYNAMIC_SHARED(char, shared_memory); +#endif // TENSORFLOW_USE_ROCM + const int batch_index = blockIdx.x; const T* batch_input = input + batch_index * length; @@ -370,7 +379,7 @@ __global__ void TopKKernel(const T* __restrict__ input, int length, int k, } template -cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards, +cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards, const T* input, int batch_size, int length, int k, bool sorted, T* output, int* indices) { // This code assumes that k is small enough that the computation @@ -395,9 +404,17 @@ cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards, } if (num_shards <= 0) { num_shards = 1; +#if GOOGLE_CUDA } else if (num_shards > 1024) { num_shards = 1024; } +#elif TENSORFLOW_USE_ROCM + // ROCm can't execute with 1024 and requires an explicit + // amdgpu_flat_work_group_size attribute with >256 + } else if (num_shards > 256) { + num_shards = 256; + } +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // We are limited by the amount of shared memory we have per block. auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry); @@ -439,8 +456,8 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, const auto& cu_stream = GetGpuStream(ctx); size_t temp_storage_bytes = -1; - // TODO(ebrevdo): Once cub supports iterators for ValueT replace that tensor - // with an iterator that directly returns the correct value. + // TODO(ebrevdo): Once gpuprim supports iterators for ValueT replace that + // tensor with an iterator that directly returns the correct value. Tensor input_indices; TF_RETURN_IF_ERROR(ctx->allocate_temp( DT_INT32, TensorShape({num_rows, num_cols}), &input_indices)); @@ -448,9 +465,9 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, input_indices_t.device(d) = input_indices_t.generate(ColumnIndexCreator(num_cols)); - cub::CountingInputIterator counting_iter(0); - cub::TransformInputIterator> + gpuprim::CountingInputIterator counting_iter(0); + gpuprim::TransformInputIterator> segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols)); Tensor temp_values; @@ -472,7 +489,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, sorted_values_ptr = temp_values.flat().data(); } - auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending( + auto err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending( /* d_temp_storage */ nullptr, /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ input, @@ -489,7 +506,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, if (err != cudaSuccess) { return errors::Internal( "TopKOp: Could not launch " - "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate " + "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to calculate " "temp_storage_bytes, status: ", cudaGetErrorString(err)); } @@ -497,7 +514,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, TF_RETURN_IF_ERROR(ctx->allocate_temp( DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), &temp_storage)); - err = cub::DeviceSegmentedRadixSort::SortPairsDescending( + err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending( /* d_temp_storage */ temp_storage.flat().data(), /* temp_storage_bytes */ temp_storage_bytes, /* d_keys_in */ input, @@ -514,7 +531,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, if (err != cudaSuccess) { return errors::Internal( "TopKOp: Could not launch " - "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, " + "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to sort input, " "temp_storage_bytes: ", temp_storage_bytes, ", status: ", cudaGetErrorString(err)); } @@ -543,8 +560,8 @@ struct TopKFunctor { const int64 num_cols, typename TTypes::Tensor values, typename TTypes::Tensor indices) { // For small k, use the heap implementation. For larger k, use - // the in-place cub sort. For k == num_cols, always use the - // in-place cub sort. The thresholds for n and k were determined + // the in-place gpuprim sort. For k == num_cols, always use the + // in-place gpuprim sort. The thresholds for n and k were determined // empirically. if (num_cols <= 1000 || k == num_cols || k >= 100) { return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols, @@ -567,6 +584,6 @@ struct TopKFunctor { } // end namespace functor } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_ diff --git a/tensorflow/core/kernels/topk_op_gpu_double.cu.cc b/tensorflow/core/kernels/topk_op_gpu_double.cu.cc index 8a5a7e71b1b..787aafdfd07 100644 --- a/tensorflow/core/kernels/topk_op_gpu_double.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_double.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -25,4 +25,4 @@ using Eigen::GpuDevice; template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu_float.cu.cc b/tensorflow/core/kernels/topk_op_gpu_float.cu.cc index 0b69396bb13..10d106248f9 100644 --- a/tensorflow/core/kernels/topk_op_gpu_float.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_float.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -25,4 +25,4 @@ using Eigen::GpuDevice; template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu_half.cu.cc b/tensorflow/core/kernels/topk_op_gpu_half.cu.cc index e53586aeca2..bde26cb0951 100644 --- a/tensorflow/core/kernels/topk_op_gpu_half.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_half.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -25,4 +25,4 @@ using Eigen::GpuDevice; template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu_int16.cu.cc b/tensorflow/core/kernels/topk_op_gpu_int16.cu.cc index 5bd310523c9..fba39300700 100644 --- a/tensorflow/core/kernels/topk_op_gpu_int16.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_int16.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -25,4 +25,4 @@ using Eigen::GpuDevice; template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu_int32.cu.cc b/tensorflow/core/kernels/topk_op_gpu_int32.cu.cc index 55b393a0c02..a017234597d 100644 --- a/tensorflow/core/kernels/topk_op_gpu_int32.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_int32.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -25,4 +25,4 @@ using Eigen::GpuDevice; template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu_int64.cu.cc b/tensorflow/core/kernels/topk_op_gpu_int64.cu.cc index 3e4a7750563..ed9f6ea52c6 100644 --- a/tensorflow/core/kernels/topk_op_gpu_int64.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_int64.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -25,4 +25,4 @@ using Eigen::GpuDevice; template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu_int8.cu.cc b/tensorflow/core/kernels/topk_op_gpu_int8.cu.cc index ac73cd170b8..647700ebcda 100644 --- a/tensorflow/core/kernels/topk_op_gpu_int8.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_int8.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -25,4 +25,4 @@ using Eigen::GpuDevice; template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu_uint16.cu.cc b/tensorflow/core/kernels/topk_op_gpu_uint16.cu.cc index bc64a2ecd63..41ab6ffa601 100644 --- a/tensorflow/core/kernels/topk_op_gpu_uint16.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_uint16.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -27,4 +27,4 @@ template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc b/tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc index 16e2e0e9420..6725f478c15 100644 --- a/tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -25,4 +25,4 @@ using Eigen::GpuDevice; template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu_uint64.cu.cc b/tensorflow/core/kernels/topk_op_gpu_uint64.cu.cc index 895247a63a2..0dd65145d41 100644 --- a/tensorflow/core/kernels/topk_op_gpu_uint64.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_uint64.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -25,4 +25,4 @@ using Eigen::GpuDevice; template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu_uint8.cu.cc b/tensorflow/core/kernels/topk_op_gpu_uint8.cu.cc index fc1a8a2c8cc..6d544291fed 100644 --- a/tensorflow/core/kernels/topk_op_gpu_uint8.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu_uint8.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "tensorflow/core/kernels/topk_op.h" @@ -25,4 +25,4 @@ using Eigen::GpuDevice; template struct functor::TopKFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index aaa0310cd1b..f2f6889c900 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/util/util.h" - namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; @@ -53,7 +52,6 @@ struct ApplyGradientDescent { } }; - template struct ApplyAdadelta { void operator()(const CPUDevice& d, typename TTypes::Flat var, @@ -350,6 +348,179 @@ struct ApplyFtrlMultiplyLinearByLr { } }; +namespace { + +template +inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1, + const T& l2, const T& lr_power, + const bool multiply_linear_by_lr) { + T quadratic; + if (multiply_linear_by_lr) { + if (lr_power == static_cast(-0.5)) { + quadratic = Eigen::numext::sqrt(accum) + static_cast(2) * l2 * lr; + } else { + quadratic = + Eigen::numext::pow(accum, -lr_power) + static_cast(2) * l2 * lr; + } + auto l1_reg_adjust = std::max(std::min(linear, l1 * lr), -l1 * lr); + return (l1_reg_adjust - linear) / quadratic; + } else { + if (lr_power == static_cast(-0.5)) { + quadratic = Eigen::numext::sqrt(accum) / lr + static_cast(2) * l2; + } else { + quadratic = + Eigen::numext::pow(accum, -lr_power) / lr + static_cast(2) * l2; + } + auto l1_reg_adjust = std::max(std::min(linear, l1), -l1); + return (l1_reg_adjust - linear) / quadratic; + } +} + +} // namespace + +template +struct SparseApplyFtrl { + Status operator()(const CPUDevice& d, typename TTypes::Matrix var_flat, + typename TTypes::Matrix accum_flat, + typename TTypes::Matrix linear_flat, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar l2_shrinkage, + typename TTypes::ConstScalar lr_power, + typename TTypes::ConstMatrix grad_flat, + typename TTypes::ConstVec indices_vec, + int64 inner_dim, bool multiply_linear_by_lr) { + const Tindex N = static_cast(indices_vec.dimension(0)); + if (N > 0) { + T lr_scalar = lr(); + T l1_scalar = l1(); + T l2_scalar = l2(); + T l2_shrinkage_scalar; + if (has_l2_shrinkage) { + l2_shrinkage_scalar = l2_shrinkage(); + } + T lr_power_scalar = lr_power(); + if (inner_dim > 1) { + const Tindex first_dim_size = + static_cast(var_flat.dimension(0)); + + for (Tindex i = 0; i < N; i++) { + const Tindex index = internal::SubtleMustCopy(indices_vec(i)); + if (!FastBoundsCheck(index, first_dim_size)) { + return errors::InvalidArgument( + strings::StrCat("Index ", index, " at offset ", i, + " in indices is out of range")); + } + auto accum = accum_flat.template chip<0>(index); + auto linear = linear_flat.template chip<0>(index); + auto grad = grad_flat.template chip<0>(i); + auto var = var_flat.template chip<0>(index); + +// TODO(sanjoy): Remove this macro. +// Use a macro to implement the computation here due to the templating of the +// eigen tensor library. +#define COMPUTE_FTRL(grad, grad_maybe_with_shrinkage) \ + auto new_accum = accum + grad.square(); \ + if (multiply_linear_by_lr) { \ + if (lr_power_scalar == static_cast(-0.5)) { \ + linear += grad_maybe_with_shrinkage * lr_scalar - \ + (new_accum.sqrt() - accum.sqrt()) * var; \ + } else { \ + linear += \ + grad_maybe_with_shrinkage * lr_scalar - \ + (new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) * \ + var; \ + } \ + } else { \ + if (lr_power_scalar == static_cast(-0.5)) { \ + linear += grad_maybe_with_shrinkage - \ + (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \ + } else { \ + linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \ + accum.pow(-lr_power_scalar)) / \ + lr_scalar * var; \ + } \ + } \ + auto l1_reg_adjust = \ + (multiply_linear_by_lr \ + ? linear.cwiseMin(l1_scalar * lr_scalar) \ + .cwiseMax(-l1_scalar * lr_scalar) \ + : linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar)); \ + auto x = l1_reg_adjust - linear; \ + if (multiply_linear_by_lr) { \ + if (lr_power_scalar == static_cast(-0.5)) { \ + auto y = new_accum.sqrt() + \ + linear.constant(static_cast(2) * l2_scalar * lr_scalar); \ + var = x / y; \ + } else { \ + auto y = new_accum.pow(-lr_power_scalar) + \ + linear.constant(static_cast(2) * l2_scalar * lr_scalar); \ + var = x / y; \ + } \ + } else { \ + if (lr_power_scalar == static_cast(-0.5)) { \ + auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) + \ + linear.constant(static_cast(2) * l2_scalar); \ + var = x / y; \ + } else { \ + auto y = \ + new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) + \ + linear.constant(static_cast(2) * l2_scalar); \ + var = x / y; \ + } \ + } \ + accum += grad.square(); + + if (has_l2_shrinkage) { + auto grad_with_shrinkage = + grad + static_cast(2) * l2_shrinkage_scalar * var; + COMPUTE_FTRL(grad, grad_with_shrinkage); + } else { + COMPUTE_FTRL(grad, grad); + } + } +#undef COMPUTE_FTRL + } else { + const Tindex first_dim_size = accum_flat.size(); + + for (Tindex i = 0; i < N; i++) { + const Tindex index = internal::SubtleMustCopy(indices_vec(i)); + if (!FastBoundsCheck(index, first_dim_size)) { + return errors::InvalidArgument( + strings::StrCat("Index ", index, " at offset ", i, + " in indices is out of range")); + } + T& a = accum_flat(index); + T& l = linear_flat(index); + T& v = var_flat(index); + T g; + if (has_l2_shrinkage) { + g = grad_flat(i) + + (static_cast(2) * l2_shrinkage_scalar * var_flat(index)); + } else { + g = grad_flat(i); + } + + T updated_a = a + grad_flat(i) * grad_flat(i); + using Eigen::numext::pow; + T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar); + if (!multiply_linear_by_lr) { + sigma /= lr_scalar; + } + T updated_l = (multiply_linear_by_lr ? l + g * lr_scalar - sigma * v + : l + g - sigma * v); + v = FtrlCompute(updated_a, updated_l, lr_scalar, l1_scalar, l2_scalar, + lr_power_scalar, multiply_linear_by_lr); + a = updated_a; + l = updated_l; + } + } + } + return Status::OK(); + } +}; + template struct ApplyMomentum { void operator()(const CPUDevice& d, typename TTypes::Flat var, @@ -483,7 +654,6 @@ struct ApplyAdamNonCuda { } }; - template struct ApplyAdam : ApplyAdamNonCuda {}; @@ -638,7 +808,6 @@ class ApplyGradientDescentOp : public OpKernel { bool use_exclusive_lock_; }; - #define REGISTER_KERNELS(D, T) \ REGISTER_KERNEL_BUILDER( \ Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint("T"), \ @@ -682,7 +851,6 @@ REGISTER_KERNELS(GPU, complex128); #endif #endif - #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -694,24 +862,12 @@ class ApplyAdadeltaOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - Var* resource; const bool sparse = false; - mutex* mu = GetTrainingVariableMutex(ctx, 0, sparse, &resource); - core::ScopedUnref scoped_unref(resource); - if (use_exclusive_lock_ && mu != nullptr) { - mutex_lock l1(*mu); - // Don't try to acquire a lock on the second ref as they share the same - // mutex. - // - // mutex_lock l2(*ctx->input_ref_mutex(1)); - DoValidate(ctx); - if (!ctx->status().ok()) return; - DoCompute(ctx); - } else { - DoValidate(ctx); - if (!ctx->status().ok()) return; - DoCompute(ctx); - } + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); + DoValidate(ctx); + if (!ctx->status().ok()) return; + DoCompute(ctx); MaybeForwardRefInputToRefOutput(ctx, 0, 0); } @@ -856,20 +1012,10 @@ class SparseApplyAdadeltaOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - Var* var; const bool sparse = true; - mutex* mu = GetTrainingVariableMutex(ctx, 0, sparse, &var); - core::ScopedUnref scoped_unref(var); - // mu_accum is actually the same mutex as mu_var since currently we use a - // global mutex. - // - // mutex* mu_accum = ctx->input_ref_mutex(1); - if (use_exclusive_lock_ && mu != nullptr) { - mutex_lock ml(*mu); - DoCompute(ctx); - } else { - DoCompute(ctx); - } + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); + DoCompute(ctx); } void DoCompute(OpKernelContext* ctx) { @@ -1538,35 +1684,6 @@ REGISTER_KERNELS(GPU, double); #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS -namespace { - -template -inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1, - const T& l2, const T& lr_power, - const bool multiply_linear_by_lr) { - T quadratic; - if (multiply_linear_by_lr) { - if (lr_power == static_cast(-0.5)) { - quadratic = Eigen::numext::sqrt(accum) + static_cast(2) * l2 * lr; - } else { - quadratic = - Eigen::numext::pow(accum, -lr_power) + static_cast(2) * l2 * lr; - } - auto l1_reg_adjust = std::max(std::min(linear, l1 * lr), -l1 * lr); - return (l1_reg_adjust - linear) / quadratic; - } else { - if (lr_power == static_cast(-0.5)) { - quadratic = Eigen::numext::sqrt(accum) / lr + static_cast(2) * l2; - } else { - quadratic = - Eigen::numext::pow(accum, -lr_power) / lr + static_cast(2) * l2; - } - auto l1_reg_adjust = std::max(std::min(linear, l1), -l1); - return (l1_reg_adjust - linear) / quadratic; - } -} -} // namespace - // Note, this op works on cpu only. template class SparseApplyAdagradOp : public OpKernel { @@ -2701,146 +2818,18 @@ class SparseApplyFtrlOp : public OpKernel { l2_shrinkage->shape().DebugString())); } - if (N > 0) { - if (inner_dim > 1) { - const Tindex first_dim_size = var.dim_size(0); - auto indices_vec = indices.vec(); - auto var_flat = var.flat_outer_dims(); - auto accum_flat = accum.flat_outer_dims(); - auto linear_flat = linear.flat_outer_dims(); - auto grad_flat = grad.flat_outer_dims(); - T lr_scalar = lr.scalar()(); - T l1_scalar = l1.scalar()(); - T l2_scalar = l2.scalar()(); - T l2_shrinkage_scalar; - if (has_l2_shrinkage) { - l2_shrinkage_scalar = l2_shrinkage->scalar()(); - } - T lr_power_scalar = lr_power.scalar()(); - - for (Tindex i = 0; i < N; i++) { - const Tindex index = internal::SubtleMustCopy(indices_vec(i)); - OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), - errors::InvalidArgument( - strings::StrCat("Index ", index, " at offset ", i, - " in indices is out of range"))); - auto accum = accum_flat.template chip<0>(index); - auto linear = linear_flat.template chip<0>(index); - auto grad = grad_flat.template chip<0>(i); - auto var = var_flat.template chip<0>(index); - -// Use a macro to implement the computation here due to the templating of the -// eigen tensor library. -#define COMPUTE_FTRL(grad, grad_maybe_with_shrinkage) \ - auto new_accum = accum + grad.square(); \ - if (multiply_linear_by_lr_) { \ - if (lr_power_scalar == static_cast(-0.5)) { \ - linear += grad_maybe_with_shrinkage * lr_scalar - \ - (new_accum.sqrt() - accum.sqrt()) * var; \ - } else { \ - linear += \ - grad_maybe_with_shrinkage * lr_scalar - \ - (new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) * \ - var; \ - } \ - } else { \ - if (lr_power_scalar == static_cast(-0.5)) { \ - linear += grad_maybe_with_shrinkage - \ - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \ - } else { \ - linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \ - accum.pow(-lr_power_scalar)) / \ - lr_scalar * var; \ - } \ - } \ - auto l1_reg_adjust = \ - (multiply_linear_by_lr_ \ - ? linear.cwiseMin(l1_scalar * lr_scalar) \ - .cwiseMax(-l1_scalar * lr_scalar) \ - : linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar)); \ - auto x = l1_reg_adjust - linear; \ - if (multiply_linear_by_lr_) { \ - if (lr_power_scalar == static_cast(-0.5)) { \ - auto y = new_accum.sqrt() + \ - linear.constant(static_cast(2) * l2_scalar * lr_scalar); \ - var = x / y; \ - } else { \ - auto y = new_accum.pow(-lr_power_scalar) + \ - linear.constant(static_cast(2) * l2_scalar * lr_scalar); \ - var = x / y; \ - } \ - } else { \ - if (lr_power_scalar == static_cast(-0.5)) { \ - auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) + \ - linear.constant(static_cast(2) * l2_scalar); \ - var = x / y; \ - } else { \ - auto y = \ - new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) + \ - linear.constant(static_cast(2) * l2_scalar); \ - var = x / y; \ - } \ - } \ - accum += grad.square(); - - if (has_l2_shrinkage) { - auto grad_with_shrinkage = - grad + static_cast(2) * l2_shrinkage_scalar * var; - COMPUTE_FTRL(grad, grad_with_shrinkage); - } else { - COMPUTE_FTRL(grad, grad); - } - } -#undef COMPUTE_FTRL - } else { - T lr_scalar = lr.scalar()(); - T l1_scalar = l1.scalar()(); - T l2_scalar = l2.scalar()(); - T lr_power_scalar = lr_power.scalar()(); - T l2_shrinkage_scalar; - if (has_l2_shrinkage) { - l2_shrinkage_scalar = l2_shrinkage->scalar()(); - } - - auto indices_vec = indices.vec(); - auto var_flat = var.flat(); - auto accum_flat = accum.flat(); - auto linear_flat = linear.flat(); - auto grad_flat = grad.flat(); - const Tindex first_dim_size = accum_flat.size(); - - for (Tindex i = 0; i < N; i++) { - const Tindex index = internal::SubtleMustCopy(indices_vec(i)); - OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), - errors::InvalidArgument( - strings::StrCat("Index ", index, " at offset ", i, - " in indices is out of range"))); - T& a = accum_flat(index); - T& l = linear_flat(index); - T& v = var_flat(index); - T g; - if (has_l2_shrinkage) { - g = grad_flat(i) + - (static_cast(2) * l2_shrinkage_scalar * var_flat(index)); - } else { - g = grad_flat(i); - } - - T updated_a = a + grad_flat(i) * grad_flat(i); - using Eigen::numext::pow; - T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar); - if (!multiply_linear_by_lr_) { - sigma /= lr_scalar; - } - T updated_l = (multiply_linear_by_lr_ ? l + g * lr_scalar - sigma * v - : l + g - sigma * v); - v = FtrlCompute(updated_a, updated_l, lr_scalar, l1_scalar, l2_scalar, - lr_power_scalar, multiply_linear_by_lr_); - a = updated_a; - l = updated_l; - } - } - } + const Device& device = ctx->template eigen_device(); + auto indices_vec = indices.vec(); + OP_REQUIRES_OK( + ctx, functor::SparseApplyFtrl()( + device, var.flat_outer_dims(), accum.flat_outer_dims(), + linear.flat_outer_dims(), lr.scalar(), l1.scalar(), + l2.scalar(), + // Note: Passing l2 as a placeholder when not has_l2_shrinkage + // (it will not be used). + has_l2_shrinkage ? l2_shrinkage->scalar() : l2.scalar(), + lr_power.scalar(), grad.flat_outer_dims(), indices_vec, + inner_dim, multiply_linear_by_lr_)); MaybeForwardRefInputToRefOutput(ctx, 0, 0); } @@ -3471,7 +3460,6 @@ class ApplyAdamOp : public OpKernel { bool use_nesterov_; }; - #define REGISTER_KERNELS(D, T) \ REGISTER_KERNEL_BUILDER( \ Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint("T"), \ @@ -3488,7 +3476,6 @@ class ApplyAdamOp : public OpKernel { TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); - #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h index ef44b5f9659..5af603077a5 100644 --- a/tensorflow/core/kernels/training_ops.h +++ b/tensorflow/core/kernels/training_ops.h @@ -18,6 +18,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -151,6 +152,21 @@ struct ApplyFtrlV2MultiplyLinearByLr { typename TTypes::ConstScalar lr_power); }; +template +struct SparseApplyFtrl { + Status operator()(const Device& d, typename TTypes::Matrix var_flat, + typename TTypes::Matrix accum_flat, + typename TTypes::Matrix linear_flat, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar l2_shrinkage, + typename TTypes::ConstScalar lr_power, + typename TTypes::ConstMatrix grad_flat, + typename TTypes::ConstVec indices_vec, + int64 inner_dim, bool multiply_linear_by_lr); +}; + template struct ApplyMomentum { void operator()(const Device& d, typename TTypes::Flat var, diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc index 659513d05ed..5fb47043654 100644 --- a/tensorflow/core/lib/gif/gif_io.cc +++ b/tensorflow/core/lib/gif/gif_io.cc @@ -85,7 +85,6 @@ uint8* Decode(const void* srcdata, int datasize, } int target_num_frames = gif_file->ImageCount; - if (!expand_animations) target_num_frames = 1; // Don't request more memory than needed for each frame, preventing OOM int max_frame_width = 0; @@ -101,6 +100,7 @@ uint8* Decode(const void* srcdata, int datasize, const int width = max_frame_width; const int height = max_frame_height; const int channel = 3; + if (!expand_animations) target_num_frames = 1; uint8* const dstdata = allocate_output(target_num_frames, width, height, channel); @@ -118,27 +118,36 @@ uint8* Decode(const void* srcdata, int datasize, if (img_desc->Left != 0 || img_desc->Top != 0 || img_desc->Width != width || img_desc->Height != height) { - // If the first frame does not fill the entire canvas then return error. + // If the first frame does not fill the entire canvas then fill the + // unoccupied canvas with zeros (black). if (k == 0) { - *error_string = "the first frame does not fill the canvas"; - return nullptr; + for (int i = 0; i < height; ++i) { + uint8* p_dst = this_dst + i * width * channel; + for (int j = 0; j < width; ++j) { + p_dst[j * channel + 0] = 0; + p_dst[j * channel + 1] = 0; + p_dst[j * channel + 2] = 0; + } + } + } else { + // Otherwise previous frame will be reused to fill the unoccupied + // canvas. + uint8* last_dst = dstdata + (k - 1) * width * channel * height; + for (int i = 0; i < height; ++i) { + uint8* p_dst = this_dst + i * width * channel; + uint8* l_dst = last_dst + i * width * channel; + for (int j = 0; j < width; ++j) { + p_dst[j * channel + 0] = l_dst[j * channel + 0]; + p_dst[j * channel + 1] = l_dst[j * channel + 1]; + p_dst[j * channel + 2] = l_dst[j * channel + 2]; + } + } } - // Otherwise previous frame will be reused to fill the unoccupied canvas. + imgLeft = std::max(imgLeft, 0); imgTop = std::max(imgTop, 0); imgRight = std::min(imgRight, width); imgBottom = std::min(imgBottom, height); - - uint8* last_dst = dstdata + (k - 1) * width * channel * height; - for (int i = 0; i < height; ++i) { - uint8* p_dst = this_dst + i * width * channel; - uint8* l_dst = last_dst + i * width * channel; - for (int j = 0; j < width; ++j) { - p_dst[j * channel + 0] = l_dst[j * channel + 0]; - p_dst[j * channel + 1] = l_dst[j * channel + 1]; - p_dst[j * channel + 2] = l_dst[j * channel + 2]; - } - } } ColorMapObject* color_map = this_image->ImageDesc.ColorMap diff --git a/tensorflow/core/lib/gif/testdata/red_black.gif b/tensorflow/core/lib/gif/testdata/red_black.gif new file mode 100644 index 00000000000..d32ddd3547d Binary files /dev/null and b/tensorflow/core/lib/gif/testdata/red_black.gif differ diff --git a/tensorflow/core/lib/gif/testdata/squares.gif b/tensorflow/core/lib/gif/testdata/squares.gif new file mode 100644 index 00000000000..159f86355a8 Binary files /dev/null and b/tensorflow/core/lib/gif/testdata/squares.gif differ diff --git a/tensorflow/core/lib/io/BUILD b/tensorflow/core/lib/io/BUILD index fc9148c0f04..eadfbd1fe2e 100644 --- a/tensorflow/core/lib/io/BUILD +++ b/tensorflow/core/lib/io/BUILD @@ -8,6 +8,7 @@ package( default_visibility = [ "//tensorflow/c/experimental/filesystem:__pkg__", "//tensorflow/c/experimental/filesystem/plugins/posix:__pkg__", + "//tensorflow/core/lib/io/snappy:__pkg__", # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** "//tensorflow/core:__pkg__", ], @@ -183,56 +184,24 @@ cc_library( alwayslink = True, ) -cc_library( +alias( name = "snappy_inputbuffer", - srcs = ["snappy/snappy_inputbuffer.cc"], - hdrs = ["snappy/snappy_inputbuffer.h"], - deps = [ - ":inputstream_interface", - "//tensorflow/core/lib/core:status", - "//tensorflow/core/platform:env", - "//tensorflow/core/platform:macros", - "//tensorflow/core/platform:platform_port", - "//tensorflow/core/platform:types", - ], - alwayslink = True, + actual = "//tensorflow/core/lib/io/snappy:snappy_inputbuffer", ) -cc_library( - name = "snappy_outputbuffer", - srcs = ["snappy/snappy_outputbuffer.cc"], - hdrs = ["snappy/snappy_outputbuffer.h"], - deps = [ - "//tensorflow/core/lib/core:status", - "//tensorflow/core/platform", - "//tensorflow/core/platform:env", - "//tensorflow/core/platform:macros", - "//tensorflow/core/platform:platform_port", - "//tensorflow/core/platform:types", - ], - alwayslink = True, -) - -cc_library( +alias( name = "snappy_inputstream", - srcs = ["snappy/snappy_inputstream.cc"], - hdrs = ["snappy/snappy_inputstream.h"], - deps = [ - ":inputstream_interface", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:platform_port", - "@com_google_absl//absl/memory", - ], - alwayslink = True, + actual = "//tensorflow/core/lib/io/snappy:snappy_inputstream", ) -cc_library( +alias( + name = "snappy_outputbuffer", + actual = "//tensorflow/core/lib/io/snappy:snappy_outputbuffer", +) + +alias( name = "snappy_compression_options", - hdrs = ["snappy/snappy_compression_options.h"], - deps = [ - "//tensorflow/core/platform:types", - ], - alwayslink = True, + actual = "//tensorflow/core/lib/io/snappy:snappy_compression_options", ) cc_library( @@ -350,9 +319,6 @@ filegroup( "random_inputstream.h", "record_reader.cc", "record_reader.h", - "snappy/snappy_compression_options.h", - "snappy/snappy_inputstream.cc", - "snappy/snappy_inputstream.h", "table.cc", "table.h", "table_builder.cc", @@ -364,6 +330,9 @@ filegroup( "zlib_compression_options.h", "zlib_inputstream.cc", "zlib_inputstream.h", + "//tensorflow/core/lib/io/snappy:snappy_compression_options.h", + "//tensorflow/core/lib/io/snappy:snappy_inputstream.cc", + "//tensorflow/core/lib/io/snappy:snappy_inputstream.h", ], ) @@ -383,10 +352,6 @@ filegroup( "random_inputstream.h", "record_reader.h", "record_writer.h", - "snappy/snappy_compression_options.h", - "snappy/snappy_inputbuffer.h", - "snappy/snappy_inputstream.h", - "snappy/snappy_outputbuffer.h", "table.h", "table_builder.h", "table_options.h", @@ -394,6 +359,10 @@ filegroup( "zlib_compression_options.h", "zlib_inputstream.h", "zlib_outputbuffer.h", + "//tensorflow/core/lib/io/snappy:snappy_compression_options.h", + "//tensorflow/core/lib/io/snappy:snappy_inputbuffer.h", + "//tensorflow/core/lib/io/snappy:snappy_inputstream.h", + "//tensorflow/core/lib/io/snappy:snappy_outputbuffer.h", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -409,9 +378,9 @@ filegroup( "random_inputstream_test.cc", "record_reader_writer_test.cc", "recordio_test.cc", - "snappy/snappy_test.cc", "table_test.cc", "zlib_buffers_test.cc", + "//tensorflow/core/lib/io/snappy:snappy_test.cc", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -440,13 +409,13 @@ filegroup( srcs = [ "inputbuffer.h", "iterator.h", - "snappy/snappy_compression_options.h", - "snappy/snappy_inputbuffer.h", - "snappy/snappy_inputstream.h", - "snappy/snappy_outputbuffer.h", "zlib_compression_options.h", "zlib_inputstream.h", "zlib_outputbuffer.h", + "//tensorflow/core/lib/io/snappy:snappy_compression_options.h", + "//tensorflow/core/lib/io/snappy:snappy_inputbuffer.h", + "//tensorflow/core/lib/io/snappy:snappy_inputstream.h", + "//tensorflow/core/lib/io/snappy:snappy_outputbuffer.h", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/io/snappy/BUILD b/tensorflow/core/lib/io/snappy/BUILD new file mode 100644 index 00000000000..3f9405cdd6a --- /dev/null +++ b/tensorflow/core/lib/io/snappy/BUILD @@ -0,0 +1,74 @@ +# Snappy targets. + +load( + "//tensorflow/core/platform:rules_cc.bzl", + "cc_library", +) + +package( + default_visibility = [ + "//tensorflow/core/lib/io:__pkg__", + ], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "snappy_compression_options.h", + "snappy_inputbuffer.h", + "snappy_inputstream.h", + "snappy_outputbuffer.h", + "snappy_inputstream.cc", + "snappy_test.cc", +]) + +cc_library( + name = "snappy_inputbuffer", + srcs = ["snappy_inputbuffer.cc"], + hdrs = ["snappy_inputbuffer.h"], + deps = [ + "//tensorflow/core/lib/core:status", + "//tensorflow/core/lib/io:inputstream_interface", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:platform_port", + "//tensorflow/core/platform:types", + ], + alwayslink = True, +) + +cc_library( + name = "snappy_outputbuffer", + srcs = ["snappy_outputbuffer.cc"], + hdrs = ["snappy_outputbuffer.h"], + deps = [ + "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:platform_port", + "//tensorflow/core/platform:types", + ], + alwayslink = True, +) + +cc_library( + name = "snappy_inputstream", + srcs = ["snappy_inputstream.cc"], + hdrs = ["snappy_inputstream.h"], + deps = [ + "//tensorflow/core/lib/io:inputstream_interface", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:platform_port", + "@com_google_absl//absl/memory", + ], + alwayslink = True, +) + +cc_library( + name = "snappy_compression_options", + hdrs = ["snappy_compression_options.h"], + deps = [ + "//tensorflow/core/platform:types", + ], + alwayslink = True, +) diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc index aa80576365e..af820084df5 100644 --- a/tensorflow/core/lib/jpeg/jpeg_mem.cc +++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc @@ -164,7 +164,7 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { cinfo.dct_method = flags.dct_method; // Determine the output image size before attempting decompress to prevent - // OOM'ing doing the decompress + // OOM'ing during the decompress jpeg_calc_output_dimensions(&cinfo); int64 total_size = static_cast(cinfo.output_height) * @@ -577,7 +577,7 @@ bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height, SetSrc(&cinfo, srcdata, datasize, false); jpeg_read_header(&cinfo, TRUE); - jpeg_start_decompress(&cinfo); // required to transfer image size to cinfo + jpeg_calc_output_dimensions(&cinfo); if (width) *width = cinfo.output_width; if (height) *height = cinfo.output_height; if (components) *components = cinfo.output_components; diff --git a/tensorflow/core/lib/jpeg/testdata/BUILD b/tensorflow/core/lib/jpeg/testdata/BUILD new file mode 100644 index 00000000000..cb245b6ddcd --- /dev/null +++ b/tensorflow/core/lib/jpeg/testdata/BUILD @@ -0,0 +1,14 @@ +# Description: +# JPEG test data packages. + +load("//tensorflow:tensorflow.bzl", "filegroup") + +package( + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "testdata", + srcs = glob(["*.jpg"]), + visibility = ["//tensorflow/core:__pkg__"], +) diff --git a/tensorflow/core/nccl/collective_communicator.cc b/tensorflow/core/nccl/collective_communicator.cc index 56e2255ae99..bcdee71be18 100644 --- a/tensorflow/core/nccl/collective_communicator.cc +++ b/tensorflow/core/nccl/collective_communicator.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/nccl/collective_communicator.h" +#include "tensorflow/core/framework/cancellation.h" + #if TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM) #include "absl/memory/memory.h" @@ -77,7 +79,25 @@ void NcclCommunicator::Enqueue(std::shared_ptr col_ctx, auto* gpu_info = col_ctx->op_ctx->device()->tensorflow_gpu_device_info(); auto participant = absl::make_unique( compute_stream->parent(), compute_stream, gpu_info, col_ctx->input, - col_ctx->output, col_ctx->col_params.default_rank, std::move(done)); + col_ctx->output, col_ctx->col_params.default_rank, + /*done_callback=*/nullptr); + CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager(); + if (cancel_mgr == nullptr) { + participant->done_callback = std::move(done); + } else { + CancellationToken cancel_token = cancel_mgr->get_cancellation_token(); + cancel_mgr->RegisterCallback(cancel_token, [this]() { + nccl_manager_.StartAbort(errors::Cancelled("op cancelled")); + nccl_manager_.Reset(); + }); + participant->done_callback = [cancel_mgr, cancel_token, + done = std::move(done)](const Status& s) { + // Do not block on deregistration since this can be invoked by + // NcclManager::StartAbort() in the cancellation callback. + cancel_mgr->TryDeregisterCallback(cancel_token); + done(s); + }; + } NcclManager::Context context( nccl_collective_key, num_local_devices, num_global_devices, col_params.group.runtime_details.communicator_key, diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc index 157c255a316..eaa34d042ce 100644 --- a/tensorflow/core/nccl/nccl_manager.cc +++ b/tensorflow/core/nccl/nccl_manager.cc @@ -115,7 +115,7 @@ struct NcclManager::Communicator { : num_devices(members.size()), members(std::move(members)), key(key) {} const int num_devices; - const std::vector members; + std::vector members; const string key; }; @@ -304,7 +304,7 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective, // Launching of kernels must be serialized so that, given collectives A and // B, and an order of them (e.g., A before B), then for each comm_stream // involved, the kernel for A is launched before the kernel for B. This is - // guaranteed currently be a global mutex controlling additions of the + // guaranteed currently by a global mutex controlling additions of the // kernels to per-stream launch queues. The launch queues are processed by // LoopKernelLaunches. for (auto& comm : communicators_) { @@ -739,6 +739,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { VLOG(2) << "call NcclAllReduce collective_key " << collective->collective_key << " participant " << p_idx + << " num_participants " << collective->participants.size() << " sendbuff " << sendbuff << " recvbuff " << recvbuff << " nccl_comm " << nccl_comm << " comm_stream " << comm_stream << " cuda_stream " << cu_stream; @@ -849,10 +850,8 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { } void NcclManager::StartAbort(const Status& s) { - VLOG(1) << "NcclManager StartAbort"; absl::flat_hash_map collectives; - // After status_ is set to a non-OK one, there should be no further - // modifications to collectives_. + std::vector> communicators; { mutex_lock l(mu_); if (!status_.ok()) { @@ -863,7 +862,11 @@ void NcclManager::StartAbort(const Status& s) { } status_ = s; collectives.swap(collectives_); + communicators.swap(communicators_); } + VLOG(2) << "Aborted NcclManager " << this << " with " << collectives.size() + << " collectives and " << communicators.size() + << " comms with status " << s; // collectives_ contains pending launches that haven't been dispatched to // kernel launch threads, so we can simply invoke the done callbacks of them. for (const auto& item : collectives) { @@ -872,21 +875,23 @@ void NcclManager::StartAbort(const Status& s) { } item.second->Unref(); } - // Abort ncclComm. Note that there could be multiple ncclComm per device, and - // ncclCommAbort contains cuda calls that requires device synchronization. - // That is a collective on nccl_comm_0 can block ncclCommAbort(nccl_comm_1), - // so we need to abort all ncclComm in a concurrent fashion. This assumes that - // there's only one active NcclManager at a time. + // Abort ncclComm. Note that there could be multiple ncclComm per device, + // and ncclCommAbort contains cuda calls that requires device + // synchronization. That is a collective on nccl_comm_0 can block + // ncclCommAbort(nccl_comm_1), so we need to abort all ncclComm in a + // concurrent fashion. This assumes that there's only one active NcclManager + // at a time. UnboundedWorkQueue queue(Env::Default(), "nccl_abort"); int num_comms = 0; - for (std::unique_ptr& communicator : communicators_) { + for (std::unique_ptr& communicator : communicators) { num_comms += communicator->members.size(); } BlockingCounter pending(num_comms); - for (std::unique_ptr& communicator : communicators_) { - for (const CommunicatorMember& member : communicator->members) { + for (std::unique_ptr& communicator : communicators) { + for (CommunicatorMember& member : communicator->members) { queue.Schedule([&member, &pending]() { ncclCommAbort(member.nccl_comm); + member.nccl_comm = nullptr; pending.DecrementCount(); }); } @@ -894,6 +899,12 @@ void NcclManager::StartAbort(const Status& s) { pending.Wait(); } +void NcclManager::Reset() { + mutex_lock l(mu_); + status_ = Status(); + VLOG(2) << "Reset NcclManager " << this; +} + } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h index 88b8bc85663..b1d9dd62f94 100644 --- a/tensorflow/core/nccl/nccl_manager.h +++ b/tensorflow/core/nccl/nccl_manager.h @@ -195,6 +195,10 @@ class NcclManager { // launched with this NcclManager. void StartAbort(const Status& s); + // Resets a previously aborted NcclManager, making it available for future + // collectives. + void Reset(); + private: enum CollectiveType { kAllReduce = 1, @@ -248,7 +252,7 @@ class NcclManager { absl::flat_hash_map> device_to_comm_streams_ TF_GUARDED_BY(mu_); - std::vector> communicators_; + std::vector> communicators_ TF_GUARDED_BY(mu_); Status status_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/nccl/nccl_manager_test.cc b/tensorflow/core/nccl/nccl_manager_test.cc index d16eefa6f72..0d0d003d63f 100644 --- a/tensorflow/core/nccl/nccl_manager_test.cc +++ b/tensorflow/core/nccl/nccl_manager_test.cc @@ -17,8 +17,6 @@ limitations under the License. #include "absl/strings/str_format.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "tensorflow/core/nccl/nccl_manager.h" - #include #include #include @@ -27,6 +25,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_device.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/nccl/nccl_manager.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/unbounded_work_queue.h" @@ -640,6 +639,10 @@ TEST(NcclManagerTest, CommunicatorKey) { } #if !TENSORFLOW_USE_ROCM +// ROCm platform currently does not support simulating a mutli-node +// environment, on a single node with multiple GPUS. So tests that rely +// upon such simulation need to be skipped on the ROCm platform + // This test creates `num_nodes` NcclManagers to simulate a multi-node // environment. It works on a single node with multiple GPUs. It enqueues NCCL // kernels on separate stream per rank. @@ -661,6 +664,10 @@ TYPED_TEST(NcclManagerTest, MultiNodeSingle) { } #if !TENSORFLOW_USE_ROCM +// ROCm platform currently does not support simulating a mutli-node +// environment, on a single node with multiple GPUS. So tests that rely +// upon such simulation need to be skipped on the ROCm platform + // Multi-node broadcast. TYPED_TEST(NcclManagerTest, MultiNodeBroadcast) { int num_nodes; @@ -850,20 +857,24 @@ TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) { this->VerifyError(test_case.get()); } -TYPED_TEST(NcclManagerTest, Abort) { +#if !TENSORFLOW_USE_ROCM +// ROCm platform currently does not support simulating a mutli-node +// environment, on a single node with multiple GPUS. So tests that rely +// upon such simulation need to be skipped on the ROCm platform + +TYPED_TEST(NcclManagerTest, AbortThenReset) { using NodeState = typename TestFixture::NodeState; using TestCase = typename TestFixture::TestCase; - int num_nodes = 2; + const int num_nodes = 2; std::vector nodes(num_nodes); // First do a normal all-reduce to simulate the the case when there're // multiple communicators. this->RunMultiNodeAllReduceTest(nodes, /* num_ranks_per_node */ 1); - // Use a new communicator_key, which uses a new set of ncclComm underneath. - string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey(); - string collective_key = "allreduce"; + const string collective_key = "allreduce"; ncclRedOp_t reduction_op = static_cast(0); - auto node_fn = [&](TestCase* test_case, int node) { + auto node_fn = [&](TestCase* test_case, int node, + const string& communicator_key) { auto* device = this->GetDevice(/* num_ranks_per_node */ 1, node, /* local_rank */ 0); auto* info = device->tensorflow_gpu_device_info(); @@ -881,6 +892,8 @@ TYPED_TEST(NcclManagerTest, Abort) { nodes[node].nccl_manager.SignalMultiNodeReady(collective_key); }; + // Use a new communicator_key, which uses a new set of ncclComm underneath. + string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey(); // Do a normal all-reduce with this communicator key to initialize ncclComm. // This is because ncclCommInitRank waits for all ranks and is blocking. { @@ -890,7 +903,9 @@ TYPED_TEST(NcclManagerTest, Abort) { TensorShape({2, 3}), 0.0f)); for (int i = 0; i < num_nodes; ++i) { this->work_queue_->Schedule( - [&node_fn, &test_case, i]() { node_fn(test_case.get(), i); }); + [&node_fn, &test_case, i, communicator_key]() { + node_fn(test_case.get(), i, communicator_key); + }); } this->VerifyResults(test_case.get()); } @@ -901,17 +916,43 @@ TYPED_TEST(NcclManagerTest, Abort) { this->MakeReductionTestCase( /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op, TensorShape({2, 3}), 0.0f)); - node_fn(test_case.get(), 0); + node_fn(test_case.get(), 0, communicator_key); Env::Default()->SleepForMicroseconds(1000000); - nodes[0].nccl_manager.StartAbort(errors::Unavailable("peer down")); + for (auto& node : nodes) { + node.nccl_manager.StartAbort(errors::Unavailable("peer down")); + } { mutex_lock l(test_case->mu); while (test_case->num_completed != 1) { test_case->done_cv.wait(l); } } + + // Reset the aborted NcclManager and then run another all-reduce with the + // resetted NcclManagers. + for (auto& node : nodes) { + node.nccl_manager.Reset(); + } + // Regenerate the communicator_key, because this is needed to create new + // communicators. + communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey(); + { + std::unique_ptr test_case( + this->MakeReductionTestCase( + /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op, + TensorShape({2, 3}), 0.0f)); + for (int i = 0; i < num_nodes; ++i) { + this->work_queue_->Schedule( + [&node_fn, &test_case, i, communicator_key]() { + node_fn(test_case.get(), i, communicator_key); + }); + } + this->VerifyResults(test_case.get()); + } } +#endif + } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/ops/compat/ops_history_v2/RaggedTensorToVariantGradient.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RaggedTensorToVariantGradient.pbtxt new file mode 100644 index 00000000000..45f2fcefe04 --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/RaggedTensorToVariantGradient.pbtxt @@ -0,0 +1,36 @@ +op { + name: "RaggedTensorToVariantGradient" + input_arg { + name: "encoded_ragged_grad" + type: DT_VARIANT + } + input_arg { + name: "row_splits" + type_attr: "Tsplits" + } + input_arg { + name: "dense_values_shape" + type: DT_INT32 + } + output_arg { + name: "dense_values_grad" + type_attr: "Tvalues" + } + attr { + name: "Tvalues" + type: "type" + } + attr { + name: "Tsplits" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index 91bcc3be49a..a19c3a6d934 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -69,7 +69,7 @@ REGISTER_OP("EmptyTensorList") 0, &element_shape)); c->set_output_handle_shapes_and_types( 0, std::vector{ - {element_shape, element_dtype}}); + {element_shape, element_dtype, ST_TENSOR_LIST}}); return Status::OK(); }); @@ -106,7 +106,7 @@ REGISTER_OP("TensorListPushBack") } c->set_output_handle_shapes_and_types( 0, std::vector{ - {element_shape, element_dtype}}); + {element_shape, element_dtype, ST_TENSOR_LIST}}); return Status::OK(); }); @@ -153,7 +153,7 @@ REGISTER_OP("TensorListPushBackBatch") } c->set_output_handle_shapes_and_types( 0, std::vector{ - {element_shape, element_dtype}}); + {element_shape, element_dtype, ST_TENSOR_LIST}}); return Status::OK(); }); @@ -345,7 +345,7 @@ REGISTER_OP("TensorListSplit") &element_shape_from_tensor_shape)); c->set_output_handle_shapes_and_types( 0, std::vector{ - {element_shape, element_dtype}}); + {element_shape, element_dtype, ST_TENSOR_LIST}}); return Status::OK(); }); diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 250dfc49f60..dd831e03d45 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -34775,6 +34775,42 @@ op { type: "bool" } } +op { + name: "RaggedTensorToVariantGradient" + input_arg { + name: "encoded_ragged_grad" + type: DT_VARIANT + } + input_arg { + name: "row_splits" + type_attr: "Tsplits" + } + input_arg { + name: "dense_values_shape" + type: DT_INT32 + } + output_arg { + name: "dense_values_grad" + type_attr: "Tvalues" + } + attr { + name: "Tvalues" + type: "type" + } + attr { + name: "Tsplits" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} op { name: "RandomCrop" input_arg { diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc index 44712bf7739..043ff469487 100644 --- a/tensorflow/core/ops/ragged_conversion_ops.cc +++ b/tensorflow/core/ops/ragged_conversion_ops.cc @@ -92,7 +92,8 @@ tensorflow::Status ValidateRowPartitionTypesAndShapes( Status RaggedTensorToSparseShapeFn(InferenceContext* c); Status RaggedTensorToVariantShapeFn(InferenceContext* c); Status RaggedTensorFromVariantShapeFn(InferenceContext* c); -tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c); +Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c); +Status RaggedTensorToTensorShapeFn(InferenceContext* c); //============================================================================== // Registered Ops @@ -129,6 +130,15 @@ REGISTER_OP("RaggedTensorFromVariant") .Attr("Tsplits: {int32, int64} = DT_INT64") .SetShapeFn(RaggedTensorFromVariantShapeFn); +REGISTER_OP("RaggedTensorToVariantGradient") + .Input("encoded_ragged_grad: variant") + .Input("row_splits: Tsplits") + .Input("dense_values_shape: int32") + .Output("dense_values_grad: Tvalues") + .Attr("Tvalues: type") + .Attr("Tsplits: {int32, int64} = DT_INT64") + .SetShapeFn(RaggedTensorToVariantGradientShapeFn); + REGISTER_OP("RaggedTensorToTensor") .Attr("T: type") .Attr("Tindex: {int64, int32}") @@ -201,6 +211,14 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) { return Status::OK(); } +Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) { + ShapeHandle shape; + TF_RETURN_IF_ERROR( + c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape)); + c->set_output(0, shape); + return Status::OK(); +} + Status RaggedTensorFromVariantShapeFn(InferenceContext* c) { int64 input_ragged_rank; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD index 849048e99be..2f94bce8c5b 100644 --- a/tensorflow/core/platform/default/BUILD +++ b/tensorflow/core/platform/default/BUILD @@ -1,7 +1,7 @@ # Tensorflow default + linux implementations of tensorflow/core/platform libraries. load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("//tensorflow:tensorflow.bzl", "filegroup") -load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_copts") +load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test", "tf_copts") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( @@ -429,12 +429,28 @@ cc_library( deps = [ "//tensorflow/core/platform", "//tensorflow/core/platform:env", + "//tensorflow/core/platform:logging", "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:str_util", "//tensorflow/core/platform:types", "//tensorflow/core/util:reporter", ], ) +tf_cc_test( + name = "test_benchmark_test", + srcs = ["test_benchmark_test.cc"], + tags = [ + "nobuilder", + "notap", + ], + deps = [ + ":test_benchmark", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "test", testonly = True, diff --git a/tensorflow/core/platform/default/test_benchmark.cc b/tensorflow/core/platform/default/test_benchmark.cc index 533c4ac1df1..3b495b0ad00 100644 --- a/tensorflow/core/platform/default/test_benchmark.cc +++ b/tensorflow/core/platform/default/test_benchmark.cc @@ -52,18 +52,48 @@ Benchmark::Benchmark(const char* name, void (*fn)(int, int, int)) Register(); } +Benchmark::Benchmark(const char* name, void (*fn)(::testing::benchmark::State&)) + : name_(name), + // -1 because the number of parameters is not part of the benchmark + // routine signature. + num_args_(-1), + fn_state_(fn) { + Register(); +} + +void Benchmark::CheckArgCount(int expected) { + if (num_args_ == expected) return; + + // Number of args is not part of function signature. + // Verify that if benchmark instantiation has previously provided args, they + // match "args". + if (num_args_ < 0) { + if (args_.empty() || instantiated_num_args_ == expected) return; + } + CHECK(false) << "Expected " << expected << " args for benchmark, but got " + << instantiated_num_args_; +} + Benchmark* Benchmark::Arg(int x) { - CHECK_EQ(num_args_, 1); + CheckArgCount(/*expected=*/1); args_.push_back(std::make_pair(x, -1)); + instantiated_num_args_ = 1; return this; } Benchmark* Benchmark::ArgPair(int x, int y) { - CHECK_EQ(num_args_, 2); + CheckArgCount(/*expected=*/2); + instantiated_num_args_ = 2; args_.push_back(std::make_pair(x, y)); return this; } +Benchmark* Benchmark::UseRealTime() { + // Do nothing. + // This only exists for API compatibility with internal benchmarks. + return this; +} + namespace { void AddRange(std::vector* dst, int lo, int hi, int mult) { @@ -210,6 +240,7 @@ void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) { static const int64 kMaxIters = 1000000000; static const double kMinTime = 0.5; int64 iters = kMinIters; + while (true) { accum_time = 0; start_time = env->NowMicros(); @@ -220,8 +251,11 @@ void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) { (*fn0_)(iters); } else if (fn1_) { (*fn1_)(iters, arg1); - } else { + } else if (fn2_) { (*fn2_)(iters, arg1, arg2); + } else if (fn_state_) { + ::testing::benchmark::State state(iters, std::vector(arg1, arg2)); + (*fn_state_)(state); } StopTiming(); const double seconds = accum_time * 1e-6; @@ -261,3 +295,38 @@ void UseRealTime() {} } // namespace testing } // namespace tensorflow + +namespace testing { +namespace benchmark { +State::State(size_t max_iterations, const std::vector& args) + : max_iterations(max_iterations), args_(args) { + completed_iterations_ = 0; +} + +void State::PauseTiming() { ::tensorflow::testing::StopTiming(); } + +void State::ResumeTiming() { ::tensorflow::testing::StartTiming(); } + +void State::SetBytesProcessed(::tensorflow::int64 bytes) { + ::tensorflow::testing::BytesProcessed(bytes); +} + +void State::SetItemsProcessed(::tensorflow::int64 items) { + ::tensorflow::testing::ItemsProcessed(items); +} + +void State::SetLabel(absl::string_view label) { + ::tensorflow::testing::SetLabel(std::string(label)); +} + +int State::range(size_t i) const { + if (i >= args_.size()) { + LOG(FATAL) << "argument for range " << i << " is not set"; + } + return args_[i]; +} + +void RunSpecifiedBenchmarks() { ::tensorflow::testing::Benchmark::Run("all"); } + +} // namespace benchmark +} // namespace testing diff --git a/tensorflow/core/platform/default/test_benchmark.h b/tensorflow/core/platform/default/test_benchmark.h index 55149e5c050..4a6892cb8c5 100644 --- a/tensorflow/core/platform/default/test_benchmark.h +++ b/tensorflow/core/platform/default/test_benchmark.h @@ -32,6 +32,12 @@ limitations under the License. #define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c) #define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c +namespace testing { +namespace benchmark { +class State; +} +} // namespace testing + namespace tensorflow { namespace testing { @@ -77,26 +83,41 @@ void DoNotOptimize(const T& var) { class Benchmark { public: - Benchmark(const char* name, void (*fn)(int)); - Benchmark(const char* name, void (*fn)(int, int)); - Benchmark(const char* name, void (*fn)(int, int, int)); + [[deprecated("use `benchmark::State&` instead.")]] Benchmark(const char* name, + void (*fn)(int)); + + [[deprecated("use `benchmark::State&` instead.")]] Benchmark(const char* name, + void (*fn)(int, + int)); + + [[deprecated("use `benchmark::State&` instead.")]] Benchmark( + const char* name, void (*fn)(int, int, int)); + + Benchmark(const char* name, void (*fn)(::testing::benchmark::State&)); Benchmark* Arg(int x); Benchmark* ArgPair(int x, int y); Benchmark* Range(int lo, int hi); Benchmark* RangePair(int lo1, int hi1, int lo2, int hi2); + + Benchmark* UseRealTime(); + static void Run(const char* pattern); private: string name_; int num_args_; + int instantiated_num_args_ = -1; std::vector > args_; void (*fn0_)(int) = nullptr; void (*fn1_)(int, int) = nullptr; void (*fn2_)(int, int, int) = nullptr; + void (*fn_state_)(::testing::benchmark::State&) = nullptr; void Register(); void Run(int arg1, int arg2, int* run_count, double* run_seconds); + + void CheckArgCount(int expected); }; void RunBenchmarks(); @@ -110,4 +131,148 @@ void UseRealTime(); } // namespace testing } // namespace tensorflow +// Support `void BM_Func(benchmark::State&)` interface so that the it is +// compatible with the internal version. +namespace testing { +namespace benchmark { +// State is passed as an argument to a benchmark function. +// Each thread in threaded benchmarks receives own object. +class State { + public: + // Incomplete iterator-like type with dummy value type so that + // benchmark::State can support iteration with a range-based for loop. + // + // The only supported usage: + // + // static void BM_Foo(benchmark::State& state) { + // for (auto s : state) { + // // perform single iteration + // } + // } + // + // This is meant to replace the deprecated API : + // + // static void BM_Foo(int iters) { + // while (iters-- > 0) { + // // perform single iteration + // } + // } + // + // See go/benchmark#old-benchmark-interface for more details. + class Iterator { + public: + struct Value { + // Non-trivial destructor to avoid warning for unused dummy variable in + // the range-based for loop. + ~Value() {} + }; + + explicit Iterator(State* parent); + + Iterator& operator++(); + + bool operator!=(const Iterator& other); + + Value operator*(); + + private: + State* const parent_; + }; + + Iterator begin(); + Iterator end(); + + void PauseTiming(); + void ResumeTiming(); + + // Set the number of bytes processed by the current benchmark + // execution. This routine is typically called once at the end of a + // throughput oriented benchmark. If this routine is called with a + // value > 0, then bytes processed per second is also reported. + void SetBytesProcessed(::tensorflow::int64 bytes); + + // If this routine is called with items > 0, then an items/s + // label is printed on the benchmark report line for the currently + // executing benchmark. It is typically called at the end of a processing + // benchmark where a processing items/second output is desired. + void SetItemsProcessed(::tensorflow::int64 items); + + // If this method is called, the specified label is printed at the + // end of the benchmark report line for the currently executing + // benchmark. Example: + // static void BM_Compress(benchmark::State& state) { + // ... + // double compression = input_size / output_size; + // state.SetLabel(StringPrintf("compress:%.1f%%", 100.0*compression)); + // } + // Produces output that looks like: + // BM_Compress 50 50 14115038 compress:27.3% + // + // REQUIRES: a benchmark is currently executing + void SetLabel(absl::string_view label); + + // For parameterized benchmarks, range(i) returns the value of the ith + // parameter. Simple benchmarks are not parameterized and do not need to call + // range(). + int range(size_t i) const; + + // Total number of iterations processed so far. + size_t iterations() const; + + const size_t + max_iterations; // NOLINT: for compatibility with OSS benchmark library + + // Disallow copy and assign. + State(const State&) = delete; + State& operator=(const State&) = delete; + + protected: + friend class tensorflow::testing::Benchmark; + State(size_t max_iterations, const std::vector& args); + + private: + size_t completed_iterations_; + std::vector args_; +}; + +inline State::Iterator::Iterator(State* parent) : parent_(parent) {} + +inline size_t State::iterations() const { return completed_iterations_; } + +inline bool State::Iterator::operator!=(const Iterator& other) { + DCHECK_EQ(other.parent_, nullptr); + DCHECK_NE(parent_, nullptr); + + if (parent_->completed_iterations_ < parent_->max_iterations) { + return true; + } + + ++parent_->completed_iterations_; + // If this is the last iteration, stop the timer. + parent_->PauseTiming(); + return false; +} + +inline State::Iterator& State::Iterator::operator++() { + DCHECK_LT(parent_->completed_iterations_, parent_->max_iterations); + ++parent_->completed_iterations_; + return *this; +} + +inline State::Iterator::Value State::Iterator::operator*() { return Value(); } + +inline State::Iterator State::begin() { + // Starts the timer here because if the code uses this API, it expects + // the timer to starts at the beginning of this loop. + ResumeTiming(); + return Iterator(this); +} + +inline State::Iterator State::end() { return Iterator(nullptr); } + +void RunSpecifiedBenchmarks(); + +} // namespace benchmark +} // namespace testing + #endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_ diff --git a/tensorflow/core/platform/default/test_benchmark_test.cc b/tensorflow/core/platform/default/test_benchmark_test.cc new file mode 100644 index 00000000000..2c692b2af7a --- /dev/null +++ b/tensorflow/core/platform/default/test_benchmark_test.cc @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/default/test_benchmark.h" + +// Test the new interface: BM_benchmark(benchmark::State& state) +namespace tensorflow { +namespace testing { +namespace { + +void BM_TestIterState(::testing::benchmark::State& state) { + int i = 0; + for (auto s : state) { + ++i; + DoNotOptimize(i); + } +} + +BENCHMARK(BM_TestIterState); + +} // namespace +} // namespace testing +} // namespace tensorflow + +int main() { + ::testing::benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 74195db7730..a92e66e67d0 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -164,9 +164,8 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) { string nn(namenode); string cacheKey(scheme.data(), scheme.size()); - hdfsBuilder* builder = libhdfs()->hdfsNewBuilder(); if (scheme == "file") { - libhdfs()->hdfsBuilderSetNameNode(builder, nullptr); + nn = ""; } else if (scheme == "viewfs") { char* defaultFS = nullptr; libhdfs()->hdfsConfGetStr("fs.defaultFS", &defaultFS); @@ -181,19 +180,21 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) { // The default NameNode configuration will be used (from the XML // configuration files). See: // https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259 - libhdfs()->hdfsBuilderSetNameNode(builder, "default"); + nn = "default"; } else if (scheme == "har") { TF_RETURN_IF_ERROR(SplitArchiveNameAndPath(path, nn)); - libhdfs()->hdfsBuilderSetNameNode(builder, nn.c_str()); - cacheKey += nn; } else { - libhdfs()->hdfsBuilderSetNameNode(builder, - nn.empty() ? "default" : nn.c_str()); - cacheKey += nn; + if (nn.empty()) { + nn = "default"; + } } + cacheKey += nn; { mutex_lock lock(mu_); if (connectionCache_.find(cacheKey) == connectionCache_.end()) { + hdfsBuilder* builder = libhdfs()->hdfsNewBuilder(); + libhdfs()->hdfsBuilderSetNameNode(builder, + nn.empty() ? nullptr : nn.c_str()); hdfsFS cacheFs = libhdfs()->hdfsBuilderConnect(builder); if (cacheFs == nullptr) { return errors::NotFound(strerror(errno)); diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.cc b/tensorflow/core/platform/profile_utils/cpu_utils.cc index 7cd1c4de88f..b76b3377397 100644 --- a/tensorflow/core/platform/profile_utils/cpu_utils.cc +++ b/tensorflow/core/platform/profile_utils/cpu_utils.cc @@ -23,6 +23,10 @@ limitations under the License. #include #endif +#if defined(__APPLE__) +#include +#endif + #include "absl/base/call_once.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h" @@ -114,17 +118,11 @@ static ICpuUtilsHelper* cpu_utils_helper_instance_ = nullptr; "CPU frequency"; return INVALID_FREQUENCY; #elif defined(__APPLE__) - int64 freq_hz; - FILE* fp = - popen("sysctl hw | grep hw.cpufrequency_max: | cut -d' ' -f 2", "r"); - if (fp == nullptr) { - return INVALID_FREQUENCY; - } - if (fscanf(fp, "%lld", &freq_hz) != 1) { - return INVALID_FREQUENCY; - } - pclose(fp); - if (freq_hz < 1e6) { + int64 freq_hz = 0; + size_t freq_size = sizeof(freq_hz); + int retval = + sysctlbyname("hw.cpufrequency_max", &freq_hz, &freq_size, NULL, 0); + if (retval != 0 || freq_hz < 1e6) { LOG(WARNING) << "Failed to get CPU frequency: " << freq_hz << " Hz"; return INVALID_FREQUENCY; } diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 38f5da03838..57f827956df 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -356,6 +356,7 @@ tf_cc_test( size = "small", srcs = ["xplane_to_op_stats_test.cc"], deps = [ + ":step_events_to_steps_db", ":xplane_to_op_stats", ":xplane_to_tf_functions", "//tensorflow/core:lib", @@ -684,17 +685,23 @@ cc_library( deps = [ ":op_stats_to_input_pipeline_analysis", ":op_stats_to_overview_page", + ":op_stats_to_pod_viewer", ":op_stats_to_tf_stats", ":xplane_to_memory_profile", ":xplane_to_op_stats", + ":xplane_to_tf_data_stats", ":xplane_to_trace_events", "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc", "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:overview_page_proto_cc", + "//tensorflow/core/profiler/protobuf:pod_viewer_proto_cc", + "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:xplane_schema", + "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/strings", ], ) @@ -705,11 +712,13 @@ cc_library( hdrs = ["xplane_to_tf_data_stats.h"], copts = tf_profiler_copts(), deps = [ + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/platform:protobuf", "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/utils:group_events", + "//tensorflow/core/profiler/utils:html_utils", "//tensorflow/core/profiler/utils:tf_op_utils", "//tensorflow/core/profiler/utils:tf_xplane_visitor", "//tensorflow/core/profiler/utils:timespan", @@ -718,6 +727,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/core/profiler/convert/op_stats_combiner.cc b/tensorflow/core/profiler/convert/op_stats_combiner.cc index 5c492a6fa44..0de30f2c4e8 100644 --- a/tensorflow/core/profiler/convert/op_stats_combiner.cc +++ b/tensorflow/core/profiler/convert/op_stats_combiner.cc @@ -97,6 +97,8 @@ void CombineRunEnvironment(const RunEnvironment& src, RunEnvironment* dst) { dst->set_num_cores_per_replica( std::max(src.num_cores_per_replica(), dst->num_cores_per_replica())); *dst->mutable_topology() = src.topology(); + } else if (dst->device_type().empty()) { + dst->set_device_type(src.device_type()); } dst->set_task_count(src.task_count() + dst->task_count()); (*dst->mutable_host_independent_job_info()) = src.host_independent_job_info(); diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc index cf0b7e6ad43..b8610eb903a 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc @@ -72,6 +72,9 @@ void ComputeHostTips(OverviewPageRecommendation* re) { *re->add_host_tips() = MakeOverviewPageTip( "input_pipeline_analyzer (especially Section 3 for the breakdown of " "input operations on the Host)"); + *re->add_host_tips() = MakeOverviewPageTip( + "tf_data_bottleneck_analysis (find the bottleneck in the tf.data input " + "pipeline)"); *re->add_host_tips() = MakeOverviewPageTip( "trace_viewer (look at the activities on the timeline of each Host " "Thread near the bottom of the trace view)"); diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc index 067d47a0b57..0a1a1e19048 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc @@ -32,7 +32,7 @@ namespace tensorflow { namespace profiler { namespace { -XEventBuilder AddTensorFlowOpEvent(absl::string_view tf_op_fullname, +XEventBuilder AddTensorFlowOpEvent(std::string&& tf_op_fullname, int64 start_timestamp_ns, int64 duration_ns, bool on_device, absl::string_view kernel_name, @@ -42,12 +42,13 @@ XEventBuilder AddTensorFlowOpEvent(absl::string_view tf_op_fullname, event.SetTimestampNs(start_timestamp_ns); event.SetDurationNs(duration_ns); if (!on_device) return event; - event.ParseAndAddStatValue(*plane->GetOrCreateStatMetadata("level 0"), - tf_op_fullname); + event.AddStatValue( + *plane->GetOrCreateStatMetadata("level 0"), + *plane->GetOrCreateStatMetadata(std::move(tf_op_fullname))); return event; } -void AddTensorFlowOpEventWithKernelDetails(absl::string_view tf_op_fullname, +void AddTensorFlowOpEventWithKernelDetails(std::string&& tf_op_fullname, int64 start_timestamp_ns, int64 duration_ns, bool on_device, absl::string_view kernel_name, @@ -55,8 +56,8 @@ void AddTensorFlowOpEventWithKernelDetails(absl::string_view tf_op_fullname, XPlaneBuilder* plane, XLineBuilder* line) { XEventBuilder event = - AddTensorFlowOpEvent(tf_op_fullname, start_timestamp_ns, duration_ns, - on_device, kernel_name, plane, line); + AddTensorFlowOpEvent(std::move(tf_op_fullname), start_timestamp_ns, + duration_ns, on_device, kernel_name, plane, line); if (!on_device) return; event.ParseAndAddStatValue(*plane->GetOrCreateStatMetadata("kernel_details"), kernel_details); diff --git a/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc b/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc index 581a003eb38..4fe3ed58366 100644 --- a/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc +++ b/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc @@ -22,10 +22,7 @@ limitations under the License. namespace tensorflow { namespace profiler { -void PostProcessSingleHostXSpace(XSpace* space, uint64 start_time_ns) { - VLOG(3) << "Post processing local profiler XSpace."; - // Post processing the collected XSpace without hold profiler lock. - // 1. Merge plane of host events with plane of CUPTI driver api. +void MergeHostPlanes(XSpace* space) { const XPlane* cupti_driver_api_plane = FindPlaneWithName(*space, kCuptiDriverApiPlaneName); const XPlane* python_tracer_plane = @@ -40,15 +37,20 @@ void PostProcessSingleHostXSpace(XSpace* space, uint64 start_time_ns) { MergePlanes(*python_tracer_plane, host_plane); } SortXLinesBy(host_plane, XLinesComparatorByName()); - // NOTE: RemovePlaneWithName might invalidate plane pointers. so do these - // at the last step. if (cupti_driver_api_plane) { - RemovePlaneWithName(space, kCuptiDriverApiPlaneName); + RemovePlane(space, cupti_driver_api_plane); } if (python_tracer_plane) { - RemovePlaneWithName(space, kPythonTracerPlaneName); + RemovePlane(space, python_tracer_plane); } } +} + +void PostProcessSingleHostXSpace(XSpace* space, uint64 start_time_ns) { + VLOG(3) << "Post processing local profiler XSpace."; + // Post processing the collected XSpace without hold profiler lock. + // 1. Merge plane of host events with plane of CUPTI driver api. + MergeHostPlanes(space); // 2. Normalize all timestamps by shifting timeline to profiling start time. // NOTE: this have to be done before sorting XSpace due to timestamp overflow. diff --git a/tensorflow/core/profiler/convert/post_process_single_host_xplane.h b/tensorflow/core/profiler/convert/post_process_single_host_xplane.h index 31ebe28c48f..70c6785591b 100644 --- a/tensorflow/core/profiler/convert/post_process_single_host_xplane.h +++ b/tensorflow/core/profiler/convert/post_process_single_host_xplane.h @@ -21,6 +21,9 @@ limitations under the License. namespace tensorflow { namespace profiler { +// Merges XPlanes generated by TraceMe, CUPTI API trace and Python tracer. +void MergeHostPlanes(XSpace* space); + // Post process XSpaces collected locally from multiple profilers. void PostProcessSingleHostXSpace(XSpace* space, uint64 start_time_ns); diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc index bdac1129c81..7d6f23db041 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc @@ -31,7 +31,7 @@ namespace tensorflow { namespace profiler { namespace { -void AddTensorFlowOpEvent(absl::string_view tf_op_fullname, +void AddTensorFlowOpEvent(std::string&& tf_op_fullname, int64 start_timestamp_ns, int64 duration_ns, bool on_device, absl::string_view kernel_name, XPlaneBuilder* plane, XLineBuilder* line) { @@ -40,8 +40,9 @@ void AddTensorFlowOpEvent(absl::string_view tf_op_fullname, event.SetTimestampNs(start_timestamp_ns); event.SetDurationNs(duration_ns); if (!on_device) return; - event.ParseAndAddStatValue(*plane->GetOrCreateStatMetadata("level 0"), - tf_op_fullname); + event.AddStatValue( + *plane->GetOrCreateStatMetadata("level 0"), + *plane->GetOrCreateStatMetadata(std::move(tf_op_fullname))); } TEST(ConvertXPlaneToOpMetricsDb, HostOpMetricsDb) { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index 9ca784c01bb..6eb67eab216 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -63,7 +63,7 @@ DeviceCapabilities GetDeviceCapFromXPlane(const XPlane& device_plane) { cap.set_num_cores(stat.IntValue()); break; case kDevCapMemoryBandwidth: - cap.set_memory_bandwidth(stat.IntValue()); // bytes/s + cap.set_memory_bandwidth(stat.UintValue()); // bytes/s break; case kDevCapMemorySize: cap.set_memory_size_in_bytes(stat.UintValue()); @@ -100,10 +100,15 @@ PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane) { namespace { -void SetRunEnvironment(int32 accelerator_count, RunEnvironment* env) { +void SetRunEnvironment(const XSpace& space, int32 accelerator_count, + RunEnvironment* env) { // Currently, we only support profiling one host and one program. env->set_host_count(1); env->set_task_count(1); + for (const auto& hostname : space.hostnames()) { + std::vector hostname_split = absl::StrSplit(hostname, ':'); + (*env->mutable_hostnames())[hostname_split[0]] = true; + } env->set_device_type(accelerator_count > 0 ? "GPU" : "CPU"); env->set_device_core_count(accelerator_count); } @@ -155,7 +160,8 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, // Convert device planes. OpMetricsDbCombiner op_metrics_db_combiner( op_stats.mutable_device_op_metrics_db()); - SetRunEnvironment(device_planes.size(), op_stats.mutable_run_environment()); + SetRunEnvironment(space, device_planes.size(), + op_stats.mutable_run_environment()); KernelReportMap reports; // TODO(b/161942993) parallelize XPlane processing per thread. @@ -201,6 +207,11 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, *op_stats.mutable_device_op_metrics_db()->mutable_precision_stats() = ComputePrecisionStats(nonoverlapped_step_events); } + + CoreDetails& details = + (*op_stats.mutable_core_id_to_details())[kDefaultGpuLocalCoreId]; + details.set_hostname(space.hostnames().empty() ? "localhost" + : space.hostnames(0)); return op_stats; } diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.h b/tensorflow/core/profiler/convert/xplane_to_op_stats.h index 09ed246e766..178f8c261f2 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.h @@ -33,7 +33,7 @@ struct OpStatsOptions { // NOTE: call GroupTfEvents before if OpStats.step_db needs to be generated. OpStats ConvertXSpaceToOpStats(const XSpace& space, - const OpStatsOptions& config); + const OpStatsOptions& options); // Propagate and dedup the diagnostics in XSpace and add to OpStats. void PropagateXSpaceDiagnosticsToOpStats(const XSpace& space, diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc index c3ccb73c078..a61c22f98a4 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h" #include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" @@ -36,33 +37,34 @@ namespace tensorflow { namespace profiler { namespace { +static constexpr char kXPlanePb[] = "xplane.pb"; + TEST(ConvertXPlaneToOpStats, PerfEnv) { XSpace space; constexpr double kMaxError = 0.01; constexpr int kClockRateKHz = 1530000; constexpr int kCoreCount = 80; - constexpr uint64 kMemoryBandwidthBytesPerSecond = 900 * 1e9; + constexpr uint64 kMemoryBandwidthBytesPerSecond = + uint64{900} * 1000 * 1000 * 1000; // Volta. constexpr int kComputeCapMajor = 7; constexpr int kComputeCapMinor = 0; XPlaneBuilder device_plane( GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0)); - device_plane.ParseAndAddStatValue( - *device_plane.GetOrCreateStatMetadata("clock_rate"), - absl::StrCat(kClockRateKHz)); - device_plane.ParseAndAddStatValue( - *device_plane.GetOrCreateStatMetadata("core_count"), - absl::StrCat(kCoreCount)); - device_plane.ParseAndAddStatValue( + device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata("clock_rate"), + kClockRateKHz); + device_plane.AddStatValue(*device_plane.GetOrCreateStatMetadata("core_count"), + kCoreCount); + device_plane.AddStatValue( *device_plane.GetOrCreateStatMetadata("memory_bandwidth"), - absl::StrCat(kMemoryBandwidthBytesPerSecond)); - device_plane.ParseAndAddStatValue( + kMemoryBandwidthBytesPerSecond); + device_plane.AddStatValue( *device_plane.GetOrCreateStatMetadata("compute_cap_major"), - absl::StrCat(kComputeCapMajor)); - device_plane.ParseAndAddStatValue( + kComputeCapMajor); + device_plane.AddStatValue( *device_plane.GetOrCreateStatMetadata("compute_cap_minor"), - absl::StrCat(kComputeCapMinor)); + kComputeCapMinor); GroupTfEvents(&space); OpStatsOptions options; @@ -178,9 +180,20 @@ TEST(ConvertXPlaneToOpStats, PropagateAndDedupErrors) { EXPECT_EQ(kError, op_stats.diagnostics().errors(/*index=*/0)); } +TEST(ConvertXPlaneToOpStats, Hostnames) { + XSpace space; + static constexpr char kHost[] = "host1"; + *space.add_hostnames() = kHost; + + OpStats op_stats = ConvertXSpaceToOpStats(space, OpStatsOptions()); + EXPECT_EQ( + kHost, + op_stats.core_id_to_details().at(kDefaultGpuLocalCoreId).hostname()); +} + // Helper function to build a XSpace and store it to test directory. -void BuildAndStoreXSpaceForTest(Env* test_env, const std::string& test_dir, - const std::string& xspace_name) { +void BuildAndStoreXSpaceForTest(Env* test_env, absl::string_view test_dir, + absl::string_view hostname) { constexpr int64 kStepNum = 123; constexpr int64 kStepId = 456; // Create a host only XSpace for test. @@ -202,6 +215,9 @@ void BuildAndStoreXSpaceForTest(Env* test_env, const std::string& test_dir, CreateXEvent(&host_plane_builder, &executor_thread, "aaa:bbb", 30, 70); GroupTfEvents(&xspace); + xspace.add_hostnames(std::string(hostname)); + + std::string xspace_name = absl::StrCat(hostname, ".", kXPlanePb); TF_CHECK_OK( WriteBinaryProto(test_env, io::JoinPath(test_dir, xspace_name), xspace)) << "Failed to write binary XSpace to file: " << xspace_name; @@ -214,14 +230,17 @@ TEST(ConvertXPlaneToOpStats, TestConvertMultiXSpacesToCombinedOpStats) { TF_CHECK_OK(test_env->CreateDir(test_dir)) << "Failed to create test directory: " << test_dir; - const std::string xspace1 = "xspace1.pb"; - const std::string xspace2 = "xspace2.pb"; - BuildAndStoreXSpaceForTest(test_env, test_dir, xspace1); - BuildAndStoreXSpaceForTest(test_env, test_dir, xspace2); + static constexpr char kHost1[] = "host1"; + static constexpr char kHost2[] = "host2"; + + BuildAndStoreXSpaceForTest(test_env, test_dir, kHost1); + BuildAndStoreXSpaceForTest(test_env, test_dir, kHost2); std::vector xspace_paths; - xspace_paths.push_back(io::JoinPath(test_dir, xspace1)); - xspace_paths.push_back(io::JoinPath(test_dir, xspace2)); + xspace_paths.push_back( + io::JoinPath(test_dir, absl::StrCat(kHost1, ".", kXPlanePb))); + xspace_paths.push_back( + io::JoinPath(test_dir, absl::StrCat(kHost2, ".", kXPlanePb))); OpStatsOptions options; options.generate_op_metrics_db = true; options.generate_step_db = true; @@ -248,8 +267,13 @@ TEST(ConvertXPlaneToOpStats, TestConvertMultiXSpacesToCombinedOpStats) { const auto& step_info_per_core = combined_op_stats.step_db().step_sequence(0).step_info_per_core(); // global_core_id is computed using: 1000 * host_id + local_core_id. - EXPECT_TRUE(step_info_per_core.contains(1)); - EXPECT_TRUE(step_info_per_core.contains(1001)); + EXPECT_TRUE(step_info_per_core.contains(kDefaultGpuLocalCoreId)); + EXPECT_TRUE(step_info_per_core.contains(1000 + kDefaultGpuLocalCoreId)); + + const auto& core_details_map = combined_op_stats.core_id_to_details(); + EXPECT_EQ(kHost1, core_details_map.at(kDefaultGpuLocalCoreId).hostname()); + EXPECT_EQ(kHost2, + core_details_map.at(1000 + kDefaultGpuLocalCoreId).hostname()); // Tear down environment and directory for testing. int64 undeleted_files, undeleted_dirs; diff --git a/tensorflow/core/profiler/convert/xplane_to_profile_response.cc b/tensorflow/core/profiler/convert/xplane_to_profile_response.cc index 10b2122d764..cf48a7d61ab 100644 --- a/tensorflow/core/profiler/convert/xplane_to_profile_response.cc +++ b/tensorflow/core/profiler/convert/xplane_to_profile_response.cc @@ -50,6 +50,7 @@ const absl::string_view kInputPipeline = "input_pipeline"; const absl::string_view kOverviewPage = "overview_page"; const absl::string_view kKernelStats = "kernel_stats"; const absl::string_view kMemoryProfile = "memory_profile"; +const absl::string_view kXPlanePb = "xplane.pb"; template void AddToolData(absl::string_view tool_name, const Proto& tool_output, @@ -74,6 +75,9 @@ Status ConvertXSpaceToProfileResponse(const XSpace& xspace, absl::flat_hash_set tools(req.tools().begin(), req.tools().end()); if (tools.empty()) return Status::OK(); + if (tools.contains(kXPlanePb)) { + AddToolData(kXPlanePb, xspace, response); + } if (tools.contains(kTraceViewer)) { Trace trace; ConvertXSpaceToTraceEvents(xspace, &trace); diff --git a/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc b/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc index d50cd9a98ff..c905713a4d8 100644 --- a/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc @@ -39,16 +39,16 @@ void CreateXSpace(XSpace* space) { thread1.AddEvent(*host_plane.GetOrCreateEventMetadata("event1")); event1.SetTimestampNs(150000); event1.SetDurationNs(10000); - event1.ParseAndAddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"), - "Relu"); + event1.AddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"), + *host_plane.GetOrCreateStatMetadata("Relu")); XLineBuilder thread2 = host_plane.GetOrCreateLine(20); thread2.SetName("thread2"); XEventBuilder event2 = thread2.AddEvent(*host_plane.GetOrCreateEventMetadata("event2")); event2.SetTimestampNs(160000); event2.SetDurationNs(10000); - event2.ParseAndAddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"), - "Conv2D"); + event2.AddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"), + *host_plane.GetOrCreateStatMetadata("Conv2D")); device_plane.SetName("gpu:0"); device_plane.SetId(1); @@ -58,8 +58,8 @@ void CreateXSpace(XSpace* space) { stream1.AddEvent(*device_plane.GetOrCreateEventMetadata("kernel1")); event3.SetTimestampNs(180000); event3.SetDurationNs(10000); - event3.ParseAndAddStatValue( - *device_plane.GetOrCreateStatMetadata("correlation id"), "55"); + event3.AddStatValue(*device_plane.GetOrCreateStatMetadata("correlation id"), + 55); } TEST(ConvertXPlaneToProfileResponse, TraceViewer) { @@ -109,6 +109,18 @@ TEST(ConvertXPlaneToProfileResponse, TensorflowStats) { ASSERT_TRUE(tf_stats_db.ParseFromString(response.tool_data(0).data())); } +TEST(ConvertXPlaneToProfileResponse, XPlane) { + XSpace xspace; + CreateXSpace(&xspace); + ProfileRequest request; + request.add_tools("xplane.pb"); + ProfileResponse response; + TF_CHECK_OK(ConvertXSpaceToProfileResponse(xspace, request, &response)); + EXPECT_EQ(1, response.tool_data_size()); + EXPECT_EQ("xplane.pb", response.tool_data(0).name()); + ASSERT_TRUE(xspace.ParseFromString(response.tool_data(0).data())); +} + } // namespace } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc index 7abc8ac37e3..af77773f78d 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc @@ -17,11 +17,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" #include "tensorflow/core/profiler/utils/group_events.h" +#include "tensorflow/core/profiler/utils/html_utils.h" #include "tensorflow/core/profiler/utils/tf_op_utils.h" #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/profiler/utils/timespan.h" @@ -30,6 +33,10 @@ limitations under the License. namespace tensorflow { namespace profiler { + +// 50 us from https://www.tensorflow.org/guide/data_performance_analysis +const int64 kSlowCallThresholdPs = 50 * 1000000; + namespace { // Returns true if the given iterator event is for a root iterator. @@ -129,7 +136,7 @@ void ProcessEventForest(const EventForest& event_forest, } } -void SetInputPipelineMetadata(int64 id, uint64 name_id, +void SetInputPipelineMetadata(int64 id, int64 name_id, bool is_device_input_pipeline, InputPipelineMetadata* metadata) { constexpr absl::string_view kHostInputPipelinePrefix = "Host:"; @@ -199,8 +206,8 @@ void ProcessInputPipelines( root_iterator_event_map, TfDataStats* tf_data_stats) { auto* input_pipelines = tf_data_stats->mutable_input_pipelines(); - uint64 num_host_input_pipelines = 0; - uint64 num_device_input_pipelines = 0; + int64 num_host_input_pipelines = 0; + int64 num_device_input_pipelines = 0; for (auto& id_and_events : *root_iterator_event_map) { auto& root_iterator_id = id_and_events.first; auto& root_iterator_events = id_and_events.second; @@ -216,35 +223,208 @@ void ProcessInputPipelines( if (result.second) { bool is_device_input_pipeline = device_input_pipeline_ids.contains(root_iterator_id); - uint64 name_id = is_device_input_pipeline ? num_device_input_pipelines++ - : num_host_input_pipelines++; + int64 name_id = is_device_input_pipeline ? num_device_input_pipelines++ + : num_host_input_pipelines++; SetInputPipelineMetadata(root_iterator_id, name_id, is_device_input_pipeline, metadata); } - uint64 sum_latency_ps = 0; - uint64 min_latency_ps = UINT64_MAX; - uint64 max_latency_ps = 0; + int64 sum_latency_ps = 0; + int64 min_latency_ps = INT64_MAX; + int64 max_latency_ps = 0; + int64 num_slow_calls = 0; for (const EventNode* root_iterator_event : root_iterator_events) { InputPipelineStat* stat = input_pipeline_stats.add_stats(); ProcessIteratorEvent(*root_iterator_event, stat, /*is_blocking*/ true); SetBottleneckIteratorId(stat); - uint64 latency_ps = root_iterator_event->GetEventVisitor().DurationPs(); + int64 latency_ps = root_iterator_event->GetEventVisitor().DurationPs(); sum_latency_ps += latency_ps; min_latency_ps = std::min(min_latency_ps, latency_ps); max_latency_ps = std::max(max_latency_ps, latency_ps); + if (latency_ps > kSlowCallThresholdPs) num_slow_calls++; } input_pipeline_stats.set_avg_latency_ps(sum_latency_ps / root_iterator_events.size()); input_pipeline_stats.set_min_latency_ps(min_latency_ps); input_pipeline_stats.set_max_latency_ps(max_latency_ps); + input_pipeline_stats.set_num_slow_calls(num_slow_calls); + } +} + +void SetBottleneckAnalysis(absl::string_view host_name, + const TfDataStats& tf_data_stats, + TfDataBottleneckAnalysis* bottleneck_analysis) { + for (const auto& id_and_stats : tf_data_stats.input_pipelines()) { + const InputPipelineStats& input_pipeline_stats = id_and_stats.second; + if (input_pipeline_stats.metadata().type() == + InputPipelineMetadata::DEVICE || + input_pipeline_stats.max_latency_ps() <= + bottleneck_analysis->max_latency_ps()) { + // Ignore device input pipelines and input pipelines faster than the + // current bottleneck. + continue; + } + bottleneck_analysis->set_host(host_name.data(), host_name.size()); + bottleneck_analysis->set_input_pipeline( + input_pipeline_stats.metadata().name()); + bottleneck_analysis->set_max_latency_ps( + input_pipeline_stats.max_latency_ps()); + const IteratorMetadata& metadata = tf_data_stats.iterator_metadata().at( + input_pipeline_stats.stats(0).bottleneck_iterator_id()); + bottleneck_analysis->set_iterator_name(metadata.name()); + bottleneck_analysis->set_iterator_long_name(metadata.long_name()); + } +} + +std::string GetSuggestion(BottleneckType type) { + constexpr absl::string_view kPlaybookLink = + "https://www.tensorflow.org/guide/data_performance_analysis"; + constexpr absl::string_view kPlaybookSourceDatasetLink = + "https://www.tensorflow.org/guide/" + "data_performance_analysis#source_datasets"; + constexpr absl::string_view kPlaybookCpuUtilizationLink = + "https://www.tensorflow.org/guide/" + "data_performance_analysis#3_are_you_reaching_high_cpu_utilization"; + constexpr absl::string_view kPlaybookTransformationLink = + "https://www.tensorflow.org/guide/" + "data_performance_analysis#transformation_datasets"; + constexpr absl::string_view kTfGuideParallelDataExtractionLink = + "https://www.tensorflow.org/guide/" + "data_performance#parallelizing_data_extraction"; + constexpr absl::string_view kTfGuideParallelTransformationLink = + "https://www.tensorflow.org/guide/" + "data_performance#parallelizing_data_transformation"; + constexpr absl::string_view kTfGuideCacheLink = + "https://www.tensorflow.org/guide/data_performance#caching"; + switch (type) { + case BottleneckType::kSlowSource: + return absl::StrFormat( + "1. Check the locality of a host and input data. Ideally, they " + "should be in the same cell (or very close, like the same " + "region).
" + "2. Parallelize reading from this dataset source. See %s and %s for " + "more details.
", + AnchorElement(kPlaybookSourceDatasetLink, "here"), + AnchorElement(kTfGuideParallelDataExtractionLink, "here")); + case BottleneckType::kSlowRemoteSource: + return absl::StrFormat( + "1. The remote data source is slow. Profile its host to analyze the " + "issue further.
" + "2. See %s for other suggestions.", + AnchorElement(kPlaybookLink, "this")); + case BottleneckType::kSlowTransformationWithParallelVersion: + return absl::StrFormat( + "1. Parallelize this transformation by setting " + "num_parallel_calls=tf.data.experimental.AUTOTUNE. See " + "%s for more details.
" + "2. Consider adding cache after this transformation if " + "your data fits into memory and it is appropriate (e.g., there is no " + "randomness in upstream transformations like shuffle). " + "See %s for more details.
" + "3. Find more resources %s.", + AnchorElement(kTfGuideParallelTransformationLink, "this"), + AnchorElement(kTfGuideCacheLink, "this"), + AnchorElement(kPlaybookTransformationLink, "here")); + case BottleneckType::kSlowTransformationWithoutParallelVersion: + return absl::StrFormat( + "1. This transformation is inherently sequential. Add outer " + "parallelism by running multiple copies of the input pipeline over " + "sharded inputs and combining the results. See %s for more " + "details.
" + "2. Consider adding cache after this transformation if " + "your data fits into memory and it is appropriate (e.g., there is no " + "randomness in upstream transformations like shuffle). " + "See %s for more details.
" + "3. Find more resources %s.", + AnchorElement(kPlaybookTransformationLink, "this"), + AnchorElement(kTfGuideCacheLink, "this"), + AnchorElement(kPlaybookCpuUtilizationLink, "here")); + default: + return absl::StrFormat("See %s for suggestions.", + AnchorElement(kPlaybookLink, "this")); + } +} + +void SetSuggestion(TfDataBottleneckAnalysis* bottleneck_analysis) { + if (bottleneck_analysis->max_latency_ps() <= kSlowCallThresholdPs) return; + bottleneck_analysis->set_suggestion( + GetSuggestion(GetBottleneckType(bottleneck_analysis->iterator_name()))); +} + +void SetSummary(CombinedTfDataStats* combined_tf_data_stats) { + int64 max_latency_ps = + combined_tf_data_stats->bottleneck_analysis().max_latency_ps(); + if (max_latency_ps > kSlowCallThresholdPs) { + combined_tf_data_stats->set_is_input_bound(true); + combined_tf_data_stats->set_summary( + "Your profile has a tf.data input pipeline slower than 50 us. Below " + "shows a bottleneck in the slow input pipeline and a suggestion on how " + "to fix it."); + } else if (max_latency_ps > 0) { + combined_tf_data_stats->set_is_input_bound(false); + combined_tf_data_stats->set_summary( + "Your profile does not have any tf.data input pipeline slower than 50 " + "us. Your job could be still input bound if this profile didn't " + "capture all workers."); + } else { + combined_tf_data_stats->set_is_input_bound(false); + combined_tf_data_stats->set_summary( + "No tf.data activitiy captured in your profile. If your job uses " + "tf.data, try to capture a longer profile."); } } } // namespace -TfDataStats ConvertXPlaneToTfDataStats(XPlane* host_plane) { - TfDataStats tf_data_stats; +BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name) { + static auto* kBottleneckTypeMap = new absl::flat_hash_map( + {// Read from storage. + {"TFRecord", BottleneckType::kSlowSource}, + {"SSTable", BottleneckType::kSlowSource}, + {"RecordIO", BottleneckType::kSlowSource}, + {"Spanner", BottleneckType::kSlowSource}, + {"TFColumn", BottleneckType::kSlowSource}, + {"SleepwalkRemoteDataset", BottleneckType::kSlowSource}, + {"TextLine", BottleneckType::kSlowSource}, + {"StitchedTimelineDataset", BottleneckType::kSlowSource}, + {"DateKeyDataset", BottleneckType::kSlowSource}, + {"CapacitorProto", BottleneckType::kSlowSource}, + {"LMDB", BottleneckType::kSlowSource}, + {"ExternalDataset", BottleneckType::kSlowSource}, + {"PearModel", BottleneckType::kSlowSource}, + {"FixedLengthRecordV2", BottleneckType::kSlowSource}, + // Read from local memory. + {"FromTensor", BottleneckType::kSlowSource}, + {"TensorSlice", BottleneckType::kSlowSource}, + {"Generator", BottleneckType::kSlowSource}, + {"SyntheticDatasetOp", BottleneckType::kSlowSource}, + // Read from remote memory. + {"GuzzlerDataGuzzlerRemoteDataset", BottleneckType::kSlowRemoteSource}, + {"ReverbDataset", BottleneckType::kSlowRemoteSource}, + {"DatasetService", BottleneckType::kSlowRemoteSource}, + {"DatasetSampleGame", BottleneckType::kSlowRemoteSource}, + {"Courier", BottleneckType::kSlowRemoteSource}, + {"ReverbEpisodeDataset", BottleneckType::kSlowRemoteSource}, + // Transformations with parallel version. + {"Map", BottleneckType::kSlowTransformationWithParallelVersion}, + {"Interleave", BottleneckType::kSlowTransformationWithParallelVersion}, + // Transformations without parallel version. + {"Filter", BottleneckType::kSlowTransformationWithoutParallelVersion}, + {"Batch", BottleneckType::kSlowTransformationWithoutParallelVersion}, + {"Unbatch", BottleneckType::kSlowTransformationWithoutParallelVersion}}); + if (auto type = + gtl::FindOrNull(*kBottleneckTypeMap, bottleneck_iterator_name)) { + return *type; + } + return BottleneckType::kOther; +} + +void CombinedTfDataStatsBuilder::Add(absl::string_view host_name, + XPlane* host_plane) { + TfDataStats& tf_data_stats = + (*combined_tf_data_stats_ + ->mutable_tf_data_stats())[std::string(host_name)]; EventForest event_forest; event_forest.AddPlanes(CreateTfXPlaneVisitor, {host_plane}); event_forest.ConnectEvents(); @@ -255,7 +435,19 @@ TfDataStats ConvertXPlaneToTfDataStats(XPlane* host_plane) { &root_iterator_event_map, &tf_data_stats); ProcessInputPipelines(device_input_pipeline_ids, &root_iterator_event_map, &tf_data_stats); - return tf_data_stats; +} + +void CombinedTfDataStatsBuilder::Finalize() { + TfDataBottleneckAnalysis* bottleneck_analysis = + combined_tf_data_stats_->mutable_bottleneck_analysis(); + for (const auto& host_name_and_tf_data_stats : + combined_tf_data_stats_->tf_data_stats()) { + SetBottleneckAnalysis(host_name_and_tf_data_stats.first, + host_name_and_tf_data_stats.second, + bottleneck_analysis); + } + if (generate_suggestion_) SetSuggestion(bottleneck_analysis); + SetSummary(combined_tf_data_stats_); } } // namespace profiler diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h index 486c198d735..5caca9d4770 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h @@ -16,13 +16,44 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" namespace tensorflow { namespace profiler { -TfDataStats ConvertXPlaneToTfDataStats(XPlane* host_plane); +TF_CONST_INIT extern const int64 kSlowCallThresholdPs; + +enum class BottleneckType { + kSlowSource, + kSlowRemoteSource, + kSlowTransformationWithParallelVersion, + kSlowTransformationWithoutParallelVersion, + kOther, +}; + +BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name); + +class CombinedTfDataStatsBuilder { + public: + explicit CombinedTfDataStatsBuilder( + CombinedTfDataStats* combined_tf_data_stats, + bool generate_suggestion = true) + : combined_tf_data_stats_(combined_tf_data_stats), + generate_suggestion_(generate_suggestion) {} + + void Add(absl::string_view host_name, XPlane* host_plane); + + // Finalizes by populating TfDataBottleneckAnalysis. + void Finalize(); + + private: + CombinedTfDataStats* combined_tf_data_stats_; + bool generate_suggestion_; +}; } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc index 5b597227c83..176db2d5469 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc @@ -45,115 +45,135 @@ TEST(XPlaneToTfDataStatsTest, HostInputPipeline) { auto consumer_thread = host_plane_builder.GetOrCreateLine(0); CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 0, - 100, {{StatType::kStepId, kPrefetchIteratorId}}); + 100000000, {{StatType::kStepId, kPrefetchIteratorId}}); CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kPrefetchConsume, 80, 20, + HostEventType::kPrefetchConsume, 80000000, 20000000, {{StatType::kElementId, kFirstElementId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 200, - 20, {{StatType::kStepId, kPrefetchIteratorId}}); + CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", + 200000000, 20000000, {{StatType::kStepId, kPrefetchIteratorId}}); CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kPrefetchConsume, 210, 10, + HostEventType::kPrefetchConsume, 210000000, 10000000, {{StatType::kElementId, kSecondElementId}}); auto producer_thread = host_plane_builder.GetOrCreateLine(1); // Blocking producer. CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kPrefetchProduce, 0, 80, + HostEventType::kPrefetchProduce, 0, 80000000, {{StatType::kElementId, kFirstElementId}}); CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::Prefetch::Range", 0, 80, + "Iterator::Prefetch::Range", 0, 80000000, {{StatType::kStepId, kRangeIteratorId}, {StatType::kParentId, kPrefetchIteratorId}}); // Non-blocking producer. CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kPrefetchProduce, 100, 80, + HostEventType::kPrefetchProduce, 100000000, 80000000, {{StatType::kElementId, kSecondElementId}}); CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::Prefetch::Range", 100, 80, + "Iterator::Prefetch::Range", 100000000, 80000000, {{StatType::kStepId, kRangeIteratorId}, {StatType::kParentId, kPrefetchIteratorId}}); - TfDataStats tf_data_stats = ConvertXPlaneToTfDataStats(&host_plane); - EXPECT_THAT(tf_data_stats, EqualsProto(R"pb( - iterator_metadata: { - key: 123, - value: { - id: 123 - name: "Prefetch" - long_name: "Iterator::Prefetch" - is_async: true - } - } - iterator_metadata: { - key: 456, - value: { - id: 456 - parent_id: 123 - name: "Range" - long_name: "Iterator::Prefetch::Range" - is_async: false - } - } - input_pipelines { - key: 123, - value: { - metadata { id: 123 type: HOST name: "Host:0" } - avg_latency_ps: 60 - min_latency_ps: 20 - max_latency_ps: 100 - stats { - bottleneck_iterator_id: 456 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 0 - duration_ps: 100 - self_time_ps: 20 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 0 - duration_ps: 80 - self_time_ps: 80 - is_blocking: true - num_calls: 1 - } - } + CombinedTfDataStats combined_tf_data_stats; + CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); + builder.Add("host1", &host_plane); + builder.Finalize(); + EXPECT_THAT( + combined_tf_data_stats, EqualsProto(R"pb( + bottleneck_analysis: { + host: "host1" + input_pipeline: "Host:0" + max_latency_ps: 100000000 + iterator_name: "Range" + iterator_long_name: "Iterator::Prefetch::Range" + suggestion: "See this for suggestions." + } + tf_data_stats: { + key: "host1" + value: { + iterator_metadata: { + key: 123, + value: { + id: 123 + name: "Prefetch" + long_name: "Iterator::Prefetch" + is_async: true + } + } + iterator_metadata: { + key: 456, + value: { + id: 456 + parent_id: 123 + name: "Range" + long_name: "Iterator::Prefetch::Range" + is_async: false + } + } + input_pipelines { + key: 123, + value: { + metadata { id: 123 type: HOST name: "Host:0" } + avg_latency_ps: 60000000 + min_latency_ps: 20000000 + max_latency_ps: 100000000 + num_slow_calls: 1 + stats { + bottleneck_iterator_id: 456 + iterator_stats { + key: 123, + value: { + id: 123 + start_time_ps: 0 + duration_ps: 100000000 + self_time_ps: 20000000 + is_blocking: true + num_calls: 1 } - stats { - bottleneck_iterator_id: 123 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 200 - duration_ps: 20 - self_time_ps: 20 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 100 - duration_ps: 80 - self_time_ps: 80 - is_blocking: false - num_calls: 1 - } - } + } + iterator_stats { + key: 456, + value: { + id: 456 + start_time_ps: 0 + duration_ps: 80000000 + self_time_ps: 80000000 + is_blocking: true + num_calls: 1 } } } - )pb")); + stats { + bottleneck_iterator_id: 123 + iterator_stats { + key: 123, + value: { + id: 123 + start_time_ps: 200000000 + duration_ps: 20000000 + self_time_ps: 20000000 + is_blocking: true + num_calls: 1 + } + } + iterator_stats { + key: 456, + value: { + id: 456 + start_time_ps: 100000000 + duration_ps: 80000000 + self_time_ps: 80000000 + is_blocking: false + num_calls: 1 + } + } + } + } + } + } + } + is_input_bound: true + summary: "Your profile has a tf.data input pipeline slower than 50 us. Below shows a bottleneck in the slow input pipeline and a suggestion on how to fix it." + )pb")); } TEST(XPlaneToTfDataStatsTest, DeviceInputPipeline) { @@ -167,92 +187,106 @@ TEST(XPlaneToTfDataStatsTest, DeviceInputPipeline) { auto consumer_thread = host_plane_builder.GetOrCreateLine(0); CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 0, - 30, {{StatType::kStepId, kPrefetchIteratorId}}); - CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 100, - 100, {{StatType::kStepId, kPrefetchIteratorId}}); + 30000000, {{StatType::kStepId, kPrefetchIteratorId}}); + CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", + 100000000, 100000000, + {{StatType::kStepId, kPrefetchIteratorId}}); CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kPrefetchConsume, 180, 20, + HostEventType::kPrefetchConsume, 180000000, 20000000, {{StatType::kElementId, kElementId}}); auto producer_thread = host_plane_builder.GetOrCreateLine(1); CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kPrefetchProduce, 100, 80, + HostEventType::kPrefetchProduce, 100000000, 80000000, {{StatType::kElementId, kElementId}}); CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::Prefetch::Generator", 100, 80, + "Iterator::Prefetch::Generator", 100000000, 80000000, {{StatType::kStepId, kRangeIteratorId}, {StatType::kParentId, kPrefetchIteratorId}}); - TfDataStats tf_data_stats = ConvertXPlaneToTfDataStats(&host_plane); - EXPECT_THAT(tf_data_stats, EqualsProto(R"pb( - iterator_metadata: { - key: 123, - value: { - id: 123 - name: "Prefetch" - long_name: "Iterator::Prefetch" - is_async: true - } - } - iterator_metadata: { - key: 456, - value: { - id: 456 - parent_id: 123 - name: "Generator" - long_name: "Iterator::Prefetch::Generator" - is_async: false - } - } - input_pipelines { - key: 123, - value: { - metadata { id: 123 type: DEVICE name: "Device:0" } - avg_latency_ps: 65 - min_latency_ps: 30 - max_latency_ps: 100 - stats { - bottleneck_iterator_id: 456 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 100 - duration_ps: 100 - self_time_ps: 20 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 100 - duration_ps: 80 - self_time_ps: 80 - is_blocking: true - num_calls: 1 - } - } + CombinedTfDataStats combined_tf_data_stats; + CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); + builder.Add("host1", &host_plane); + builder.Finalize(); + // Device input pipeline is not considered for bottleneck analysis. + EXPECT_THAT( + combined_tf_data_stats, EqualsProto(R"pb( + bottleneck_analysis: {} + tf_data_stats: { + key: "host1" + value: { + iterator_metadata: { + key: 123, + value: { + id: 123 + name: "Prefetch" + long_name: "Iterator::Prefetch" + is_async: true + } + } + iterator_metadata: { + key: 456, + value: { + id: 456 + parent_id: 123 + name: "Generator" + long_name: "Iterator::Prefetch::Generator" + is_async: false + } + } + input_pipelines { + key: 123, + value: { + metadata { id: 123 type: DEVICE name: "Device:0" } + avg_latency_ps: 65000000 + min_latency_ps: 30000000 + max_latency_ps: 100000000 + num_slow_calls: 1 + stats { + bottleneck_iterator_id: 456 + iterator_stats { + key: 123, + value: { + id: 123 + start_time_ps: 100000000 + duration_ps: 100000000 + self_time_ps: 20000000 + is_blocking: true + num_calls: 1 } - stats { - bottleneck_iterator_id: 123 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 0 - duration_ps: 30 - self_time_ps: 30 - is_blocking: true - num_calls: 1 - } - } + } + iterator_stats { + key: 456, + value: { + id: 456 + start_time_ps: 100000000 + duration_ps: 80000000 + self_time_ps: 80000000 + is_blocking: true + num_calls: 1 } } } - )pb")); + stats { + bottleneck_iterator_id: 123 + iterator_stats { + key: 123, + value: { + id: 123 + start_time_ps: 0 + duration_ps: 30000000 + self_time_ps: 30000000 + is_blocking: true + num_calls: 1 + } + } + } + } + } + } + } + summary: "No tf.data activitiy captured in your profile. If your job uses tf.data, try to capture a longer profile." + )pb")); } // Test with the following example dataset: @@ -272,83 +306,103 @@ TEST(XPlaneToTfDataStatsTest, MapAndBatch) { XLineBuilder consumer_thread = host_plane_builder.GetOrCreateLine(0); CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::MapAndBatch", - 0, 100, {{StatType::kStepId, kMapAndBatchIteratorId}}); + 0, 100000000, {{StatType::kStepId, kMapAndBatchIteratorId}}); CreateXEvent(&host_plane_builder, &consumer_thread, - HostEventType::kMapAndBatchConsume, 80, 20, + HostEventType::kMapAndBatchConsume, 80000000, 20000000, {{StatType::kElementId, kElementId}}); XLineBuilder producer_thread = host_plane_builder.GetOrCreateLine(1); CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kMapAndBatchProduce, 0, 30, + HostEventType::kMapAndBatchProduce, 0, 30000000, {{StatType::kElementId, kElementId}}); CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::MapAndBatch::Range", 0, 30, + "Iterator::MapAndBatch::Range", 0, 30000000, {{StatType::kStepId, kRangeIteratorId}, {StatType::kParentId, kMapAndBatchIteratorId}}); CreateXEvent(&host_plane_builder, &producer_thread, - HostEventType::kMapAndBatchProduce, 40, 30, + HostEventType::kMapAndBatchProduce, 40000000, 30000000, {{StatType::kElementId, kElementId}}); CreateXEvent(&host_plane_builder, &producer_thread, - "Iterator::MapAndBatch::Range", 40, 30, + "Iterator::MapAndBatch::Range", 40000000, 30000000, {{StatType::kStepId, kRangeIteratorId}, {StatType::kParentId, kMapAndBatchIteratorId}}); - TfDataStats tf_data_stats = ConvertXPlaneToTfDataStats(&host_plane); - EXPECT_THAT(tf_data_stats, EqualsProto(R"pb( - iterator_metadata: { - key: 123, - value: { - id: 123 - name: "MapAndBatch" - long_name: "Iterator::MapAndBatch" - is_async: true + CombinedTfDataStats combined_tf_data_stats; + CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); + builder.Add("host1", &host_plane); + builder.Finalize(); + EXPECT_THAT( + combined_tf_data_stats, EqualsProto(R"pb( + bottleneck_analysis: { + host: "host1" + input_pipeline: "Host:0" + max_latency_ps: 100000000 + iterator_name: "Range" + iterator_long_name: "Iterator::MapAndBatch::Range" + suggestion: "See this for suggestions." + } + tf_data_stats: { + key: "host1" + value: { + iterator_metadata: { + key: 123, + value: { + id: 123 + name: "MapAndBatch" + long_name: "Iterator::MapAndBatch" + is_async: true + } + } + iterator_metadata: { + key: 456, + value: { + id: 456 + parent_id: 123 + name: "Range" + long_name: "Iterator::MapAndBatch::Range" + is_async: false + } + } + input_pipelines { + key: 123, + value: { + metadata { id: 123 type: HOST name: "Host:0" } + avg_latency_ps: 100000000 + min_latency_ps: 100000000 + max_latency_ps: 100000000 + num_slow_calls: 1 + stats { + bottleneck_iterator_id: 456 + iterator_stats { + key: 123, + value: { + id: 123 + start_time_ps: 0 + duration_ps: 100000000 + self_time_ps: 40000000 + is_blocking: true + num_calls: 1 + } } - } - iterator_metadata: { - key: 456, - value: { - id: 456 - parent_id: 123 - name: "Range" - long_name: "Iterator::MapAndBatch::Range" - is_async: false - } - } - input_pipelines { - key: 123, - value: { - metadata { id: 123 type: HOST name: "Host:0" } - avg_latency_ps: 100 - min_latency_ps: 100 - max_latency_ps: 100 - stats { - bottleneck_iterator_id: 456 - iterator_stats { - key: 123, - value: { - id: 123 - start_time_ps: 0 - duration_ps: 100 - self_time_ps: 40 - is_blocking: true - num_calls: 1 - } - } - iterator_stats { - key: 456, - value: { - id: 456 - start_time_ps: 0 - duration_ps: 60 - self_time_ps: 60 - is_blocking: true - num_calls: 2 - } - } + iterator_stats { + key: 456, + value: { + id: 456 + start_time_ps: 0 + duration_ps: 60000000 + self_time_ps: 60000000 + is_blocking: true + num_calls: 2 } } } - )pb")); + } + } + } + } + is_input_bound: true + summary: "Your profile has a tf.data input pipeline slower than 50 us. Below shows a bottleneck in the slow input pipeline and a suggestion on how to fix it." + )pb")); } } // namespace diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc index 59af75109d0..f1648744370 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc @@ -20,18 +20,25 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" #include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h" +#include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h" #include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h" #include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h" #include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" +#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h" #include "tensorflow/core/profiler/convert/xplane_to_trace_events.h" #include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/overview_page.pb.h" +#include "tensorflow/core/profiler/protobuf/pod_viewer.pb.h" +#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" #include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/xplane_schema.h" +#include "tensorflow/core/profiler/utils/xplane_utils.h" namespace tensorflow { namespace profiler { @@ -153,6 +160,60 @@ std::pair ConvertXSpaceToMemoryProfile( return std::make_pair(json_output, true); } +std::pair ConvertMultiXSpacesToPodViewer( + const std::vector& xspace_paths) { + OpStatsOptions options; + options.generate_op_metrics_db = true; + options.generate_step_db = true; + OpStats combined_op_stats; + Status status = ConvertMultiXSpacesToCombinedOpStats(xspace_paths, options, + &combined_op_stats); + if (!status.ok()) { + LOG(WARNING) << "Could not generate OpStats for pod_viewer. Error: " + << status.error_message(); + return std::make_pair("", false); + } + + std::string json_output; + protobuf::util::JsonPrintOptions opts; + opts.always_print_primitive_fields = true; + auto encode_status = protobuf::util::MessageToJsonString( + ConvertOpStatsToPodViewer(combined_op_stats), &json_output, opts); + if (!encode_status.ok()) { + LOG(WARNING) << "Could not convert pod viewer proto to json. Error: " + << encode_status.message(); + return std::make_pair("", false); + } + return std::make_pair(json_output, true); +} + +std::pair ConvertMultiXSpacesToTfDataBottleneckAnalysis( + const std::vector& xspace_paths) { + CombinedTfDataStats combined_tf_data_stats; + CombinedTfDataStatsBuilder builder(&combined_tf_data_stats); + for (const std::string& xspace_path : xspace_paths) { + XSpace xspace; + Status status = ReadBinaryProto(Env::Default(), xspace_path, &xspace); + if (!status.ok()) { + LOG(WARNING) << "Could not read XSpace for tf data stats: " + << xspace_path; + return std::make_pair("", false); + } + XPlane* host_plane = + FindMutablePlaneWithName(&xspace, kHostThreadsPlaneName); + if (host_plane == nullptr) { + LOG(WARNING) << "Could not find host XPlane for tf data stats: " + << xspace_path; + return std::make_pair("", false); + } + absl::string_view host_name = + xspace.hostnames_size() ? xspace.hostnames(0) : xspace_path; + builder.Add(host_name, host_plane); + } + builder.Finalize(); + return std::make_pair(combined_tf_data_stats.SerializeAsString(), true); +} + } // namespace std::pair ConvertMultiXSpacesToToolData( @@ -170,6 +231,10 @@ std::pair ConvertMultiXSpacesToToolData( return ConvertMultiXSpacesToKernelStats(xspace_paths); } else if (tool_name == "memory_profile") { return ConvertXSpaceToMemoryProfile(xspace_paths); + } else if (tool_name == "pod_viewer") { + return ConvertMultiXSpacesToPodViewer(xspace_paths); + } else if (tool_name == "tf_data_bottleneck_analysis") { + return ConvertMultiXSpacesToTfDataBottleneckAnalysis(xspace_paths); } else { LOG(WARNING) << "Can not find tool: " << tool_name << ". Please update to " << "the latest version of Tensorflow."; diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc index 1e0d27ae68a..a9b6a29704f 100644 --- a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc @@ -35,16 +35,16 @@ void CreateXSpace(XSpace* space) { thread1.AddEvent(*host_plane.GetOrCreateEventMetadata("event1")); event1.SetTimestampNs(150000); event1.SetDurationNs(10000); - event1.ParseAndAddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"), - "Relu"); + event1.AddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"), + *host_plane.GetOrCreateStatMetadata("Relu")); XLineBuilder thread2 = host_plane.GetOrCreateLine(20); thread2.SetName("thread2"); XEventBuilder event2 = thread2.AddEvent(*host_plane.GetOrCreateEventMetadata("event2")); event2.SetTimestampNs(160000); event2.SetDurationNs(10000); - event2.ParseAndAddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"), - "Conv2D"); + event2.AddStatValue(*host_plane.GetOrCreateStatMetadata("tf_op"), + *host_plane.GetOrCreateStatMetadata("Conv2D")); XPlaneBuilder device_plane(space->add_planes()); device_plane.SetName(GpuPlaneName(0)); @@ -55,8 +55,8 @@ void CreateXSpace(XSpace* space) { stream1.AddEvent(*device_plane.GetOrCreateEventMetadata("kernel1")); event3.SetTimestampNs(180000); event3.SetDurationNs(10000); - event3.ParseAndAddStatValue( - *device_plane.GetOrCreateStatMetadata("correlation id"), "55"); + event3.AddStatValue(*device_plane.GetOrCreateStatMetadata("correlation id"), + 55); } TEST(ConvertXPlaneToTraceEvents, Convert) { diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc index 1b48df0a650..1cd9f0bfe33 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc @@ -152,7 +152,7 @@ TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) { ASSERT_EQ(plane.name(), kHostThreadsPlaneName); ASSERT_EQ(plane.lines_size(), 1); ASSERT_EQ(plane.event_metadata_size(), 7); - ASSERT_EQ(plane.stat_metadata_size(), 2); + ASSERT_EQ(plane.stat_metadata_size(), 4); const auto& line = plane.lines(0); EXPECT_EQ(line.id(), thread_id); EXPECT_EQ(line.name(), thread_name); diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD index aa88db2b224..80c87b055a1 100644 --- a/tensorflow/core/profiler/internal/gpu/BUILD +++ b/tensorflow/core/profiler/internal/gpu/BUILD @@ -43,10 +43,6 @@ tf_cuda_library( "//tensorflow/core/profiler/lib:profiler_factory", "//tensorflow/core/profiler/lib:profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", - "//tensorflow/core/profiler/utils:parse_annotation", - "//tensorflow/core/profiler/utils:xplane_builder", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -128,11 +124,11 @@ tf_cuda_library( copts = tf_profiler_copts() + tf_copts(), visibility = ["//visibility:public"], deps = [ + ":cupti_collector", ":cupti_interface", ":cupti_utils", "//tensorflow/core:lib", "//tensorflow/core/profiler/internal/cpu:annotation_stack", - "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_set", @@ -140,6 +136,28 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "cupti_collector", + srcs = if_cuda_is_configured_compat(["cupti_collector.cc"]), + hdrs = if_cuda_is_configured_compat(["cupti_collector.h"]), + copts = tf_profiler_copts() + tf_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:parse_annotation", + "//tensorflow/core/profiler/utils:xplane_builder", + "//tensorflow/core/profiler/utils:xplane_schema", + "//tensorflow/core/profiler/utils:xplane_utils", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/strings", + ], +) + tf_cuda_library( name = "cupti_utils", srcs = if_cuda_is_configured_compat(["cupti_utils.cc"]), diff --git a/tensorflow/core/profiler/internal/gpu/cupti_collector.cc b/tensorflow/core/profiler/internal/gpu/cupti_collector.cc new file mode 100644 index 00000000000..2c5e62992c9 --- /dev/null +++ b/tensorflow/core/profiler/internal/gpu/cupti_collector.cc @@ -0,0 +1,546 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/internal/gpu/cupti_collector.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "tensorflow/core/platform/abi.h" +#include "tensorflow/core/platform/host_info.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/profiler/utils/parse_annotation.h" +#include "tensorflow/core/profiler/utils/xplane_builder.h" +#include "tensorflow/core/profiler/utils/xplane_schema.h" +#include "tensorflow/core/profiler/utils/xplane_utils.h" + +namespace tensorflow { +namespace profiler { + +namespace { + +bool IsHostEvent(const CuptiTracerEvent& event) { + // DriverCallback(i.e. kernel launching) events are host events. + if (event.source == CuptiTracerEventSource::DriverCallback) return true; + // Non-overhead activity events are device events. + if (event.type != CuptiTracerEventType::Overhead) return false; + // Overhead events can be associated with a thread or a stream, etc. + // If a valid thread id is specified, we consider it as a host event. + return event.thread_id != CuptiTracerEvent::kInvalidThreadId; +} + +void CreateXEvent(const CuptiTracerEvent& event, XPlaneBuilder* plane, + uint64 start_gpu_ns, uint64 end_gpu_ns, XLineBuilder* line) { + if (event.start_time_ns < start_gpu_ns || event.end_time_ns > end_gpu_ns || + event.start_time_ns > event.end_time_ns) { + VLOG(2) << "events have abnormal timestamps:" << event.name + << " start time(ns): " << event.start_time_ns + << " end time(ns): " << event.end_time_ns; + return; + } + std::string kernel_name = port::MaybeAbiDemangle(event.name.c_str()); + if (kernel_name.empty()) { + kernel_name = GetTraceEventTypeName(event.type); + } + XEventMetadata* event_metadata = + plane->GetOrCreateEventMetadata(std::move(kernel_name)); + XEventBuilder xevent = line->AddEvent(*event_metadata); + xevent.SetTimestampNs(event.start_time_ns); + xevent.SetEndTimestampNs(event.end_time_ns); + if (event.correlation_id != CuptiTracerEvent::kInvalidCorrelationId) { + xevent.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kCorrelationId)), + event.correlation_id); + } + if (!event.annotation.empty()) { + xevent.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kKernelAnnotation)), + *plane->GetOrCreateStatMetadata(event.annotation)); + } + if (event.context_id != CuptiTracerEvent::kInvalidContextId) { + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kContextId)), + absl::StrCat("$$", static_cast(event.context_id))); + } + if (event.type == CuptiTracerEventType::Kernel) { + std::string kernel_details = absl::StrCat( + "regs:", event.kernel_info.registers_per_thread, + " shm:", event.kernel_info.static_shared_memory_usage, + " grid:", event.kernel_info.grid_x, ",", event.kernel_info.grid_y, ",", + event.kernel_info.grid_z, " block:", event.kernel_info.block_x, ",", + event.kernel_info.block_y, ",", event.kernel_info.block_z); + xevent.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kKernelDetails)), + *plane->GetOrCreateStatMetadata(kernel_details)); + } else if (event.type == CuptiTracerEventType::MemcpyH2D || + event.type == CuptiTracerEventType::MemcpyD2H || + event.type == CuptiTracerEventType::MemcpyD2D || + event.type == CuptiTracerEventType::MemcpyP2P || + event.type == CuptiTracerEventType::MemcpyOther) { + const auto& memcpy_info = event.memcpy_info; + std::string memcpy_details = absl::StrCat("size:", memcpy_info.num_bytes, + " dest:", memcpy_info.destination, + " async:", memcpy_info.async); + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kMemcpyDetails)), + *plane->GetOrCreateStatMetadata(std::move(memcpy_details))); + } else if (event.type == CuptiTracerEventType::MemoryAlloc) { + std::string memalloc_details = + absl::StrCat("num_bytes:", event.memalloc_info.num_bytes); + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kMemallocDetails)), + *plane->GetOrCreateStatMetadata(std::move(memalloc_details))); + } + + std::vector annotation_stack = + ParseAnnotationStack(event.annotation); + // If multiple metadata have the same key name, show the values from the top + // of the stack (innermost annotation). Concatenate the values from "hlo_op". + absl::flat_hash_set key_set; + std::vector hlo_op_names; + for (auto annotation = annotation_stack.rbegin(); + annotation != annotation_stack.rend(); ++annotation) { + for (const Annotation::Metadata& metadata : annotation->metadata) { + if (metadata.key == "tf_op") { + continue; // ignored, obtained from HLO proto via DebugInfoMap + } else if (key_set.insert(metadata.key).second) { + xevent.ParseAndAddStatValue( + *plane->GetOrCreateStatMetadata(metadata.key), metadata.value); + } + } + } + // TODO(profiler): we should get rid of kLevel0, it is based on the assumption + // that those op-related ScopedAnnotation are at the very TOP level. + if (!annotation_stack.empty()) { + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kLevel0)), + *plane->GetOrCreateStatMetadata(annotation_stack.begin()->name)); + } +} + +absl::optional GetDeviceAttribute(CUdevice device, + CUdevice_attribute attrib) { + int ret_val; + CUresult err = cuDeviceGetAttribute(&ret_val, attrib, device); + if (err != CUDA_SUCCESS) return absl::nullopt; + return ret_val; +} + +std::string GetDeviceXLineName( + int64 stream_id, absl::flat_hash_set& event_types) { + std::string line_name = absl::StrCat("Stream #", stream_id); + event_types.erase(CuptiTracerEventType::Unsupported); + if (event_types.empty()) return line_name; + std::vector type_names; + for (const auto event_type : event_types) { + type_names.emplace_back(GetTraceEventTypeName(event_type)); + } + return absl::StrCat(line_name, "(", absl::StrJoin(type_names, ","), ")"); +} + +} // namespace + +void AnnotationMap::Add(uint32 device_id, uint32 correlation_id, + const std::string& annotation) { + if (annotation.empty()) return; + VLOG(3) << "Add annotation: device_id: " << device_id + << " correlation_id: " << correlation_id + << " annotation: " << annotation; + if (device_id >= per_device_map_.size()) return; + auto& per_device_map = per_device_map_[device_id]; + absl::MutexLock lock(&per_device_map.mutex); + if (per_device_map.annotations.size() < max_size_) { + absl::string_view annotation_str = + *per_device_map.annotations.insert(annotation).first; + per_device_map.correlation_map.emplace(correlation_id, annotation_str); + } +} + +absl::string_view AnnotationMap::LookUp(uint32 device_id, + uint32 correlation_id) { + if (device_id >= per_device_map_.size()) return absl::string_view(); + auto& per_device_map = per_device_map_[device_id]; + absl::MutexLock lock(&per_device_map.mutex); + auto it = per_device_map.correlation_map.find(correlation_id); + return it != per_device_map.correlation_map.end() ? it->second + : absl::string_view(); +} + +// CuptiTraceCollectorImpl store the CuptiTracerEvents from CuptiTracer and +// eventually convert and filter them to StepStats or XSpace. +class CuptiTraceCollectorImpl : public CuptiTraceCollector { + public: + CuptiTraceCollectorImpl(const CuptiTracerCollectorOptions& option, + uint64 start_walltime_ns, uint64 start_gpu_ns) + : CuptiTraceCollector(option), + num_callback_events_(0), + num_activity_events_(0), + start_walltime_ns_(start_walltime_ns), + start_gpu_ns_(start_gpu_ns), + num_gpus_(option.num_gpus), + per_device_collector_(option.num_gpus) {} + + void AddEvent(CuptiTracerEvent&& event) override { + if (event.device_id >= num_gpus_) return; + if (event.source == CuptiTracerEventSource::DriverCallback) { + if (num_callback_events_ > options_.max_callback_api_events) { + OnEventsDropped("total driver(callback) events reaches max", 1); + return; + } + num_callback_events_++; + } else { + if (num_activity_events_ > options_.max_activity_api_events) { + OnEventsDropped("total device(activity) events reaches max", 1); + return; + } + num_activity_events_++; + } + per_device_collector_[event.device_id].AddEvent(std::move(event)); + } + void OnEventsDropped(const std::string& reason, uint32 num_events) override { + absl::MutexLock lock(&mutex_); + dropped_events_[reason] += num_events; + } + void Flush() override {} + void Export(StepStats* step_stats) override { + LOG(INFO) << " GpuTracer has collected " << num_callback_events_ + << " callback api events and " << num_activity_events_ + << " activity events. " << ReportDroppedEvents(); + for (int i = 0; i < num_gpus_; ++i) { + per_device_collector_[i].Flush(i, start_walltime_ns_, start_gpu_ns_, + step_stats); + } + } + void Export(XSpace* space, uint64 end_gpu_ns) override { + LOG(INFO) << " GpuTracer has collected " << num_callback_events_ + << " callback api events and " << num_activity_events_ + << " activity events. " << ReportDroppedEvents(); + XPlaneBuilder host_plane( + FindOrAddMutablePlaneWithName(space, kCuptiDriverApiPlaneName)); + for (int device_ordinal = 0; device_ordinal < num_gpus_; ++device_ordinal) { + std::string name = GpuPlaneName(device_ordinal); + XPlaneBuilder device_plane(FindOrAddMutablePlaneWithName(space, name)); + device_plane.SetId(device_ordinal); + per_device_collector_[device_ordinal].Flush(start_gpu_ns_, end_gpu_ns, + &device_plane, &host_plane); + per_device_collector_[device_ordinal].GetDeviceCapabilities( + device_ordinal, &device_plane); + NormalizeTimeStamps(&device_plane, start_walltime_ns_); + } + NormalizeTimeStamps(&host_plane, start_walltime_ns_); + } + std::string ReportDroppedEvents() { + absl::MutexLock lock(&mutex_); + string result; + for (const auto& dropped : dropped_events_) { + absl::StrAppend(&result, " ", dropped.second, " events dropped because ", + dropped.first, ";"); + } + if (!result.empty()) result.back() = '.'; + return result; + } + std::string ReportNumEventsIfDropped() override { + std::string events_dropped = ReportDroppedEvents(); + if (events_dropped.empty()) return ""; + return absl::StrCat("Detected GPU events dropped on ", port::Hostname(), + ": Profiler has collected ", + num_callback_events_.load(), " driver events and ", + num_activity_events_.load(), " device events.", + events_dropped); + } + + private: + std::atomic num_callback_events_; + std::atomic num_activity_events_; + absl::Mutex mutex_; + absl::flat_hash_map dropped_events_ + ABSL_GUARDED_BY(mutex_); + uint64 start_walltime_ns_; + uint64 start_gpu_ns_; + int num_gpus_; + + // Set the all XLines of specified XPlane to starting walltime. + // Events time in both host and device planes are CUTPI timestamps. + // We set initial CUPTI timestamp as start time for all lines to reflect + // this fact. Eventually we change line start time to corresponding + // start_walltime_ns to normalize with CPU wall time. + static void NormalizeTimeStamps(XPlaneBuilder* plane, + uint64 start_walltime_ns) { + plane->ForEachLine( + [&](XLineBuilder line) { line.SetTimestampNs(start_walltime_ns); }); + } + + struct CorrelationInfo { + CorrelationInfo(uint32 t, uint32 e) : thread_id(t), enqueue_time_ns(e) {} + uint32 thread_id; + uint64 enqueue_time_ns; + }; + struct PerDeviceCollector { + void AddEvent(CuptiTracerEvent&& event) { + mutex_lock l(m); + if (event.source == CuptiTracerEventSource::DriverCallback) { + // Cupti api callback events were used to populate launch times etc. + if (event.correlation_id != CuptiTracerEvent::kInvalidCorrelationId) { + correlation_info.insert( + {event.correlation_id, + CorrelationInfo(event.thread_id, event.start_time_ns)}); + } + events.emplace_back(std::move(event)); + } else { + // Cupti activity events measure device times etc. + events.emplace_back(std::move(event)); + } + } + + void Flush(int32 device_ordinal, uint64 start_walltime_ns, + uint64 start_gpu_ns, StepStats* step_stats) { + mutex_lock l(m); + absl::flat_hash_map, + DeviceStepStats*> + stream_dev_stats_map; + DeviceStepStats* unknown_stream_dev_stats = nullptr; + DeviceStepStats* all_streams_dev_stats = nullptr; + DeviceStepStats* memcpy_dev_stats = nullptr; + DeviceStepStats* sync_dev_stats = nullptr; + for (const CuptiTracerEvent& event : events) { + NodeExecStats* ns = new NodeExecStats; + ns->set_all_start_micros( + (start_walltime_ns + (event.start_time_ns - start_gpu_ns)) / 1000); + ns->set_op_start_rel_micros(0); + auto elapsed_ns = event.end_time_ns - event.start_time_ns; + ns->set_op_end_rel_micros(elapsed_ns / 1000); + ns->set_all_end_rel_micros(elapsed_ns / 1000); + + if (event.source == CuptiTracerEventSource::DriverCallback) { + // Legacy code ignore all other launch events except + // cuStreamSynchronize. + if (event.name == "cuStreamSynchronize") { + ns->set_node_name(event.name); + ns->set_timeline_label(absl::StrCat("ThreadId ", event.thread_id)); + ns->set_thread_id(event.thread_id); + if (sync_dev_stats == nullptr) { + sync_dev_stats = step_stats->add_dev_stats(); + sync_dev_stats->set_device( + absl::StrCat("/device:GPU:", device_ordinal, "/sync")); + } + sync_dev_stats->add_node_stats()->Swap(ns); + } + } else { // CuptiTracerEventSource::Activity + // Get launch information if available. + if (event.correlation_id != CuptiTracerEvent::kInvalidCorrelationId) { + auto it = correlation_info.find(event.correlation_id); + if (it != correlation_info.end()) { + ns->set_scheduled_micros(it->second.enqueue_time_ns / 1000); + ns->set_thread_id(it->second.thread_id); + } + } + + auto annotation_stack = ParseAnnotationStack(event.annotation); + std::string kernel_name = port::MaybeAbiDemangle(event.name.c_str()); + std::string activity_name = + !annotation_stack.empty() + ? std::string(annotation_stack.back().name) + : kernel_name; + ns->set_node_name(activity_name); + switch (event.type) { + case CuptiTracerEventType::Kernel: { + ns->set_timeline_label(absl::StrCat( + kernel_name, " regs:", event.kernel_info.registers_per_thread, + " shm:", event.kernel_info.static_shared_memory_usage, + " grid: ", event.kernel_info.grid_x, ",", + event.kernel_info.grid_y, ",", event.kernel_info.grid_z, + " block:", event.kernel_info.block_x, ",", + event.kernel_info.block_y, ",", event.kernel_info.block_z, + "@@", event.annotation)); + DeviceStepStats*& stream_dev_stats = + stream_dev_stats_map[std::make_pair(event.stream_id, + event.type)]; + if (stream_dev_stats == nullptr) { + stream_dev_stats = step_stats->add_dev_stats(); + stream_dev_stats->set_device( + absl::StrCat("/device:GPU:", device_ordinal, + "/stream:", event.stream_id)); + } + *stream_dev_stats->add_node_stats() = *ns; + if (all_streams_dev_stats == nullptr) { + all_streams_dev_stats = step_stats->add_dev_stats(); + all_streams_dev_stats->set_device(absl::StrCat( + "/device:GPU:", device_ordinal, "/stream:all")); + } + all_streams_dev_stats->add_node_stats()->Swap(ns); + break; + } + case CuptiTracerEventType::MemcpyH2D: + case CuptiTracerEventType::MemcpyD2H: + case CuptiTracerEventType::MemcpyD2D: + case CuptiTracerEventType::MemcpyP2P: { + std::string details = absl::StrCat( + activity_name, " bytes:", event.memcpy_info.num_bytes); + if (event.memcpy_info.async) { + absl::StrAppend(&details, " aync"); + } + if (event.memcpy_info.destination != event.device_id) { + absl::StrAppend(&details, + " to device:", event.memcpy_info.destination); + } + ns->set_timeline_label(std::move(details)); + DeviceStepStats*& stream_dev_stats = + stream_dev_stats_map[std::make_pair(event.stream_id, + event.type)]; + if (stream_dev_stats == nullptr) { + stream_dev_stats = step_stats->add_dev_stats(); + stream_dev_stats->set_device(absl::StrCat( + "/device:GPU:", device_ordinal, "/stream:", event.stream_id, + "<", GetTraceEventTypeName(event.type), ">")); + } + *stream_dev_stats->add_node_stats() = *ns; + if (memcpy_dev_stats == nullptr) { + memcpy_dev_stats = step_stats->add_dev_stats(); + memcpy_dev_stats->set_device( + absl::StrCat("/device:GPU:", device_ordinal, "/memcpy")); + } + memcpy_dev_stats->add_node_stats()->Swap(ns); + break; + } + default: + ns->set_timeline_label(activity_name); + if (unknown_stream_dev_stats == nullptr) { + unknown_stream_dev_stats = step_stats->add_dev_stats(); + unknown_stream_dev_stats->set_device( + absl::StrCat("/device:GPU:", device_ordinal, "/stream:")); + } + unknown_stream_dev_stats->add_node_stats()->Swap(ns); + break; + } + } + } + events.clear(); + } + + void Flush(uint64 start_gpu_ns, uint64 end_gpu_ns, + XPlaneBuilder* device_plane, XPlaneBuilder* host_plane) { + mutex_lock l(m); + // Tracking event types per line. + absl::flat_hash_map> + events_types_per_line; + for (auto& event : events) { + bool is_host_event = IsHostEvent(event); + int64 line_id = is_host_event ? static_cast(event.thread_id) + : event.stream_id; + if (line_id == CuptiTracerEvent::kInvalidThreadId || + line_id == CuptiTracerEvent::kInvalidStreamId) + continue; + auto* plane = is_host_event ? host_plane : device_plane; + XLineBuilder line = plane->GetOrCreateLine(line_id); + line.SetTimestampNs(start_gpu_ns); + CreateXEvent(event, plane, start_gpu_ns, end_gpu_ns, &line); + events_types_per_line[line_id].emplace(event.type); + } + device_plane->ForEachLine([&](XLineBuilder line) { + line.SetName( + GetDeviceXLineName(line.Id(), events_types_per_line[line.Id()])); + }); + host_plane->ForEachLine([&](XLineBuilder line) { + line.SetName(absl::StrCat("Host Threads/", line.Id())); + }); + events.clear(); + } + + void GetDeviceCapabilities(int32 device_ordinal, + XPlaneBuilder* device_plane) { + CUdevice device; + if (cuDeviceGet(&device, device_ordinal) != CUDA_SUCCESS) return; + + auto clock_rate_in_khz = + GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_CLOCK_RATE); + if (clock_rate_in_khz) { + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapClockRateKHz)), + *clock_rate_in_khz); + } + + auto core_count = + GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT); + if (core_count) { + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapCoreCount)), + *core_count); + } + + auto mem_clock_khz = + GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE); + auto mem_bus_width_bits = GetDeviceAttribute( + device, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH); + if (mem_clock_khz && mem_bus_width_bits) { + // Times 2 because HBM is DDR memory; it gets two data bits per each + // data lane. + auto memory_bandwidth = + uint64{2} * (*mem_clock_khz) * 1000 * (*mem_bus_width_bits) / 8; + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapMemoryBandwidth)), + memory_bandwidth); + } + + size_t total_memory = 0; + if (cuDeviceTotalMem(&total_memory, device) == CUDA_SUCCESS) { + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapMemorySize)), + static_cast(total_memory)); + } + + auto compute_capability_major = GetDeviceAttribute( + device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR); + if (compute_capability_major) { + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapComputeCapMajor)), + *compute_capability_major); + } + auto compute_capability_minor = GetDeviceAttribute( + device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR); + if (compute_capability_minor) { + device_plane->AddStatValue( + *device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapComputeCapMinor)), + *compute_capability_minor); + } + } + + mutex m; + std::vector events TF_GUARDED_BY(m); + absl::flat_hash_map correlation_info + TF_GUARDED_BY(m); + }; + absl::FixedArray per_device_collector_; + + TF_DISALLOW_COPY_AND_ASSIGN(CuptiTraceCollectorImpl); +}; + +std::unique_ptr CreateCuptiCollector( + const CuptiTracerCollectorOptions& options, const uint64 start_walltime_ns, + const uint64 start_gputime_ns) { + return absl::make_unique(options, start_walltime_ns, + start_gputime_ns); +} + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/internal/gpu/cupti_collector.h b/tensorflow/core/profiler/internal/gpu/cupti_collector.h new file mode 100644 index 00000000000..302312777e4 --- /dev/null +++ b/tensorflow/core/profiler/internal/gpu/cupti_collector.h @@ -0,0 +1,204 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_COLLECTOR_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_COLLECTOR_H_ + +#include + +#include "absl/container/fixed_array.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +struct MemcpyDetails { + // The amount of data copied for memcpy events. + size_t num_bytes; + // The destination device for peer-2-peer communication (memcpy). The source + // device is implicit: it's the current device. + uint32 destination; + // Whether or not the memcpy is asynchronous. + bool async; + // This contains CUpti_ActivityMemcpyKind for activity event (on device). + // For events from other CuptiTracerEventSource, it is always 0. + int8 kind; + // CUpti_ActivityMemoryKind of source. + int8 src_mem_kind; + // CUpti_ActivityMemoryKind of destination. + int8 dst_mem_kind; +}; + +struct MemAllocDetails { + // The amount of data requested for cudaMalloc events. + uint64 num_bytes; +}; + +struct KernelDetails { + // The number of registers used in this kernel. + uint64 registers_per_thread; + // The amount of shared memory space used by a thread block. + uint64 static_shared_memory_usage; + // The amount of dynamic memory space used by a thread block. + uint64 dynamic_shared_memory_usage; + // X-dimension of a thread block. + uint64 block_x; + // Y-dimension of a thread block. + uint64 block_y; + // Z-dimension of a thread block. + uint64 block_z; + // X-dimension of a grid. + uint64 grid_x; + // Y-dimension of a grid. + uint64 grid_y; + // Z-dimension of a grid. + uint64 grid_z; +}; + +enum class CuptiTracerEventType { + Unsupported = 0, + Kernel = 1, + MemcpyH2D = 2, + MemcpyD2H = 3, + MemcpyD2D = 4, + MemcpyP2P = 5, + MemcpyOther = 6, + MemoryAlloc = 7, + Overhead = 8, + UnifiedMemory = 9, + Generic = 100, +}; + +const char* GetTraceEventTypeName(const CuptiTracerEventType& type); + +enum class CuptiTracerEventSource { + DriverCallback = 0, + Activity = 1, + // Maybe consider adding runtime callback and metric api in the future. +}; + +struct CuptiTracerEvent { + static constexpr uint32 kInvalidThreadId = + std::numeric_limits::max(); + static constexpr uint32 kInvalidCorrelationId = + std::numeric_limits::max(); + static constexpr uint64 kInvalidContextId = + std::numeric_limits::max(); + static constexpr uint64 kInvalidStreamId = + std::numeric_limits::max(); + CuptiTracerEventType type; + CuptiTracerEventSource source; + // Although CUpti_CallbackData::functionName is persistent, however + // CUpti_ActivityKernel4::name is not persistent, therefore we need a copy of + // it. + std::string name; + // This points to strings in AnnotationMap, which should outlive the point + // where serialization happens. + absl::string_view annotation; + uint64 start_time_ns; + uint64 end_time_ns; + uint32 device_id; + uint32 correlation_id = kInvalidCorrelationId; + uint32 thread_id = kInvalidThreadId; + int64 context_id = kInvalidContextId; + int64 stream_id = kInvalidStreamId; + union { + MemcpyDetails memcpy_info; // If type == Memcpy* + MemAllocDetails memalloc_info; // If type == MemoryAlloc + KernelDetails kernel_info; // If type == Kernel + }; +}; + +struct CuptiTracerCollectorOptions { + // Maximum number of events to collect from callback API; if -1, no limit. + // if 0, the callback API is enabled to build a correlation map, but no + // events are collected. + uint64 max_callback_api_events = 2 * 1024 * 1024; + // Maximum number of events to collect from activity API; if -1, no limit. + uint64 max_activity_api_events = 2 * 1024 * 1024; + // Maximum number of annotation strings that we can accommodate. + uint64 max_annotation_strings = 1024 * 1024; + // Number of GPUs involved. + uint32 num_gpus; +}; + +class AnnotationMap { + public: + explicit AnnotationMap(uint64 max_size, uint32 num_gpus) + : max_size_(max_size), per_device_map_(num_gpus) {} + void Add(uint32 device_id, uint32 correlation_id, + const std::string& annotation); + absl::string_view LookUp(uint32 device_id, uint32 correlation_id); + + private: + struct PerDeviceAnnotationMap { + // The population/consumption of annotations might happen from multiple + // callback/activity api related threads. + absl::Mutex mutex; + // Annotation tends to be repetitive, use a hash_set to store the strings, + // an use the reference to the string in the map. + absl::node_hash_set annotations; + absl::flat_hash_map correlation_map; + }; + const uint64 max_size_; + absl::FixedArray per_device_map_; + + TF_DISALLOW_COPY_AND_ASSIGN(AnnotationMap); +}; + +class CuptiTraceCollector { + public: + explicit CuptiTraceCollector(const CuptiTracerCollectorOptions& options) + : options_(options), + annotation_map_(options.max_annotation_strings, options.num_gpus) {} + virtual ~CuptiTraceCollector() {} + + // Producer side functions (i.e. called by CuptiTracer). + virtual void AddEvent(CuptiTracerEvent&& event) = 0; + virtual void OnEventsDropped(const std::string& reason, + uint32 num_events) = 0; + virtual void Flush() = 0; + + // Consumer side functions (i.e. called by GPU tracer); + virtual void Export(StepStats* step_stats) {} + virtual void Export(XSpace* space, uint64 end_gpu_ns) {} + virtual std::string ReportNumEventsIfDropped() { return ""; } + + AnnotationMap* annotation_map() { return &annotation_map_; } + + protected: + CuptiTracerCollectorOptions options_; + + private: + AnnotationMap annotation_map_; + + TF_DISALLOW_COPY_AND_ASSIGN(CuptiTraceCollector); +}; + +std::unique_ptr CreateCuptiCollector( + const CuptiTracerCollectorOptions& options, const uint64 start_walltime_ns, + const uint64 start_gputime_ns); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_COLLECTOR_H_ diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc index aedb1722fad..d3e2b7d56b4 100644 --- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc +++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc @@ -1344,32 +1344,6 @@ const char *GetTraceEventTypeName(const CuptiTracerEventType &type) { } } -void AnnotationMap::Add(uint32 device_id, uint32 correlation_id, - const std::string &annotation) { - if (annotation.empty()) return; - VLOG(3) << "Add annotation: device_id: " << device_id - << " correlation_id: " << correlation_id - << " annotation: " << annotation; - if (device_id >= per_device_map_.size()) return; - auto &per_device_map = per_device_map_[device_id]; - absl::MutexLock lock(&per_device_map.mutex); - if (per_device_map.annotations.size() < max_size_) { - absl::string_view annotation_str = - *per_device_map.annotations.insert(annotation).first; - per_device_map.correlation_map.emplace(correlation_id, annotation_str); - } -} - -absl::string_view AnnotationMap::LookUp(uint32 device_id, - uint32 correlation_id) { - if (device_id >= per_device_map_.size()) return absl::string_view(); - auto &per_device_map = per_device_map_[device_id]; - absl::MutexLock lock(&per_device_map.mutex); - auto it = per_device_map.correlation_map.find(correlation_id); - return it != per_device_map.correlation_map.end() ? it->second - : absl::string_view(); -} - /* static */ CuptiTracer *CuptiTracer::GetCuptiTracerSingleton() { static auto *singleton = new CuptiTracer(GetCuptiInterface()); return singleton; diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.h b/tensorflow/core/profiler/internal/gpu/cupti_tracer.h index a62c08013e8..3f7a2d4d7e1 100644 --- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.h +++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.h @@ -16,117 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_TRACER_H_ #define TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_TRACER_H_ -#include "absl/container/fixed_array.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/node_hash_set.h" #include "absl/types/optional.h" #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/internal/gpu/cupti_collector.h" #include "tensorflow/core/profiler/internal/gpu/cupti_interface.h" namespace tensorflow { namespace profiler { -struct MemcpyDetails { - // The amount of data copied for memcpy events. - size_t num_bytes; - // The destination device for peer-2-peer communication (memcpy). The source - // device is implicit: its the current device. - uint32 destination; - // Whether or not the memcpy is asynchronous. - bool async; - // This contains CUpti_ActivityMemcpyKind for activity event (on device). - // For events from other CuptiTracerEventSource, it is always 0. - int8 kind; - // CUpti_ActivityMemoryKind of source. - int8 src_mem_kind; - // CUpti_ActivityMemoryKind of destination. - int8 dst_mem_kind; -}; - -struct MemAllocDetails { - // The amount of data requested for cudaMalloc events. - uint64 num_bytes; -}; - -struct KernelDetails { - // The number of registers used in this kernel. - uint64 registers_per_thread; - // The amount of shared memory space used by a thread block. - uint64 static_shared_memory_usage; - // The amount of dynamic memory space used by a thread block. - uint64 dynamic_shared_memory_usage; - // X-dimension of a thread block. - uint64 block_x; - // Y-dimension of a thread block. - uint64 block_y; - // Z-dimension of a thread block. - uint64 block_z; - // X-dimension of a grid. - uint64 grid_x; - // Y-dimension of a grid. - uint64 grid_y; - // Z-dimension of a grid. - uint64 grid_z; -}; - -enum class CuptiTracerEventType { - Unsupported = 0, - Kernel = 1, - MemcpyH2D = 2, - MemcpyD2H = 3, - MemcpyD2D = 4, - MemcpyP2P = 5, - MemcpyOther = 6, - MemoryAlloc = 7, - Overhead = 8, - UnifiedMemory = 9, - Generic = 100, -}; - -const char* GetTraceEventTypeName(const CuptiTracerEventType& type); - -enum class CuptiTracerEventSource { - DriverCallback = 0, - Activity = 1, - // Maybe consider adding runtime callback and metric api in the future. -}; - -struct CuptiTracerEvent { - static constexpr uint32 kInvalidThreadId = - std::numeric_limits::max(); - static constexpr uint32 kInvalidCorrelationId = - std::numeric_limits::max(); - static constexpr uint64 kInvalidContextId = - std::numeric_limits::max(); - static constexpr uint64 kInvalidStreamId = - std::numeric_limits::max(); - CuptiTracerEventType type; - CuptiTracerEventSource source; - // Although CUpti_CallbackData::functionName is persistent, however - // CUpti_ActivityKernel4::name is not persistent, therefore we need a copy of - // it. - std::string name; - // This points to strings in AnnotationMap, which should outlive the point - // where serialization happens. - absl::string_view annotation; - uint64 start_time_ns; - uint64 end_time_ns; - uint32 device_id; - uint32 correlation_id = kInvalidCorrelationId; - uint32 thread_id = kInvalidThreadId; - int64 context_id = kInvalidContextId; - int64 stream_id = kInvalidStreamId; - union { - MemcpyDetails memcpy_info; // If type == Memcpy* - MemAllocDetails memalloc_info; // If type == MemoryAlloc - KernelDetails kernel_info; // If type == Kernel - }; -}; - struct CuptiTracerOptions { bool enable_activity_api = true; @@ -151,66 +52,6 @@ struct CuptiTracerOptions { bool sync_devices_before_stop = false; }; -struct CuptiTracerCollectorOptions { - // Maximum number of events to collect from callback API; if -1, no limit. - // if 0, the callback API is enabled to build a correlation map, but no - // events are collected. - uint64 max_callback_api_events = 2 * 1024 * 1024; - // Maximum number of events to collect from activity API; if -1, no limit. - uint64 max_activity_api_events = 2 * 1024 * 1024; - // Maximum number of annotation strings that we can accommodate. - uint64 max_annotation_strings = 1024 * 1024; - // Number of GPUs involved. - uint32 num_gpus; -}; - -class AnnotationMap { - public: - explicit AnnotationMap(uint64 max_size, uint32 num_gpus) - : max_size_(max_size), per_device_map_(num_gpus) {} - void Add(uint32 device_id, uint32 correlation_id, - const std::string& annotation); - absl::string_view LookUp(uint32 device_id, uint32 correlation_id); - - private: - struct PerDeviceAnnotationMap { - // The population/consumption of annotations might happen from multiple - // callback/activity api related threads. - absl::Mutex mutex; - // Annotation tends to be repetitive, use a hash_set to store the strings, - // an use the reference to the string in the map. - absl::node_hash_set annotations; - absl::flat_hash_map correlation_map; - }; - const uint64 max_size_; - absl::FixedArray per_device_map_; - - TF_DISALLOW_COPY_AND_ASSIGN(AnnotationMap); -}; - -class CuptiTraceCollector { - public: - explicit CuptiTraceCollector(const CuptiTracerCollectorOptions& options) - : options_(options), - annotation_map_(options.max_annotation_strings, options.num_gpus) {} - virtual ~CuptiTraceCollector() {} - - virtual void AddEvent(CuptiTracerEvent&& event) = 0; - virtual void OnEventsDropped(const std::string& reason, - uint32 num_events) = 0; - virtual void Flush() = 0; - - AnnotationMap* annotation_map() { return &annotation_map_; } - - protected: - CuptiTracerCollectorOptions options_; - - private: - AnnotationMap annotation_map_; - - TF_DISALLOW_COPY_AND_ASSIGN(CuptiTraceCollector); -}; - class CuptiDriverApiHook { public: virtual ~CuptiDriverApiHook() {} diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer.cc b/tensorflow/core/profiler/internal/gpu/device_tracer.cc index ed675e712a1..da5e5955389 100644 --- a/tensorflow/core/profiler/internal/gpu/device_tracer.cc +++ b/tensorflow/core/profiler/internal/gpu/device_tracer.cc @@ -23,510 +23,23 @@ limitations under the License. #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/synchronization/mutex.h" #include "tensorflow/core/framework/step_stats.pb.h" -#include "tensorflow/core/platform/abi.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/profiler/internal/cpu/annotation_stack.h" +#include "tensorflow/core/profiler/internal/gpu/cupti_collector.h" #include "tensorflow/core/profiler/internal/gpu/cupti_tracer.h" #include "tensorflow/core/profiler/internal/gpu/cupti_wrapper.h" #include "tensorflow/core/profiler/lib/profiler_factory.h" #include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" -#include "tensorflow/core/profiler/utils/parse_annotation.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/util/env_var.h" namespace tensorflow { namespace profiler { -namespace { - -bool IsHostEvent(const CuptiTracerEvent& event) { - // DriverCallback(i.e. kernel launching) events are host events. - if (event.source == CuptiTracerEventSource::DriverCallback) return true; - // Non-overhead activity events are device events. - if (event.type != CuptiTracerEventType::Overhead) return false; - // Overhead events can be associated with a thread or a stream, etc. - // If a valid thread id is specified, we consider it as a host event. - return event.thread_id != CuptiTracerEvent::kInvalidThreadId; -} - -void CreateXEvent(const CuptiTracerEvent& event, XPlaneBuilder* plane, - uint64 start_gpu_ns, uint64 end_gpu_ns, XLineBuilder* line) { - if (event.start_time_ns < start_gpu_ns || event.end_time_ns > end_gpu_ns || - event.start_time_ns > event.end_time_ns) { - VLOG(2) << "events have abnormal timestamps:" << event.name - << " start time(ns): " << event.start_time_ns - << " end time(ns): " << event.end_time_ns; - return; - } - std::string kernel_name = port::MaybeAbiDemangle(event.name.c_str()); - if (kernel_name.empty()) { - kernel_name = GetTraceEventTypeName(event.type); - } - XEventMetadata* event_metadata = - plane->GetOrCreateEventMetadata(std::move(kernel_name)); - XEventBuilder xevent = line->AddEvent(*event_metadata); - xevent.SetTimestampNs(event.start_time_ns); - xevent.SetEndTimestampNs(event.end_time_ns); - if (event.correlation_id != CuptiTracerEvent::kInvalidCorrelationId) { - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kCorrelationId)), - event.correlation_id); - } - if (!event.annotation.empty()) { - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kKernelAnnotation)), - *plane->GetOrCreateStatMetadata(event.annotation)); - } - if (event.context_id != CuptiTracerEvent::kInvalidContextId) { - xevent.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kContextId)), - absl::StrCat("$$", static_cast(event.context_id))); - } - if (event.type == CuptiTracerEventType::Kernel) { - std::string kernel_details = absl::StrCat( - "regs:", event.kernel_info.registers_per_thread, - " shm:", event.kernel_info.static_shared_memory_usage, - " grid:", event.kernel_info.grid_x, ",", event.kernel_info.grid_y, ",", - event.kernel_info.grid_z, " block:", event.kernel_info.block_x, ",", - event.kernel_info.block_y, ",", event.kernel_info.block_z); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kKernelDetails)), - *plane->GetOrCreateStatMetadata(kernel_details)); - } else if (event.type == CuptiTracerEventType::MemcpyH2D || - event.type == CuptiTracerEventType::MemcpyD2H || - event.type == CuptiTracerEventType::MemcpyD2D || - event.type == CuptiTracerEventType::MemcpyP2P || - event.type == CuptiTracerEventType::MemcpyOther) { - const auto& memcpy_info = event.memcpy_info; - std::string memcpy_details = absl::StrCat("size:", memcpy_info.num_bytes, - " dest:", memcpy_info.destination, - " async:", memcpy_info.async); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kMemcpyDetails)), - memcpy_details); - } else if (event.type == CuptiTracerEventType::MemoryAlloc) { - std::string memalloc_details = - absl::StrCat("num_bytes:", event.memalloc_info.num_bytes); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kMemallocDetails)), - memalloc_details); - } - - std::vector annotation_stack = - ParseAnnotationStack(event.annotation); - // If multiple metadata have the same key name, show the values from the top - // of the stack (innermost annotation). Concatenate the values from "hlo_op". - absl::flat_hash_set key_set; - std::vector hlo_op_names; - for (auto annotation = annotation_stack.rbegin(); - annotation != annotation_stack.rend(); ++annotation) { - for (const Annotation::Metadata& metadata : annotation->metadata) { - if (metadata.key == "tf_op") { - continue; // ignored, obtained from HLO proto via DebugInfoMap - } else if (key_set.insert(metadata.key).second) { - xevent.ParseAndAddStatValue( - *plane->GetOrCreateStatMetadata(metadata.key), metadata.value); - } - } - } - // TODO(profiler): we should get rid of kLevel0, it is based on the assumption - // that those op-related ScopedAnnotation are at the very TOP level. - if (!annotation_stack.empty()) { - xevent.AddStatValue( - *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kLevel0)), - *plane->GetOrCreateStatMetadata(annotation_stack.begin()->name)); - } -} - -absl::optional GetDeviceAttribute(CUdevice device, - CUdevice_attribute attrib) { - int ret_val; - CUresult err = cuDeviceGetAttribute(&ret_val, attrib, device); - if (err != CUDA_SUCCESS) return absl::nullopt; - return ret_val; -} - -std::string GetDeviceXLineName( - int64 stream_id, absl::flat_hash_set& event_types) { - std::string line_name = absl::StrCat("Stream #", stream_id); - event_types.erase(CuptiTracerEventType::Unsupported); - if (event_types.empty()) return line_name; - std::vector type_names; - for (const auto event_type : event_types) { - type_names.emplace_back(GetTraceEventTypeName(event_type)); - } - return absl::StrCat(line_name, "(", absl::StrJoin(type_names, ","), ")"); -} - -} // namespace - -// CuptiTraceCollectorImpl store the CuptiTracerEvents from CuptiTracer and -// eventually convert and filter them to StepStats or XSpace. -class CuptiTraceCollectorImpl : public CuptiTraceCollector { - public: - CuptiTraceCollectorImpl(const CuptiTracerCollectorOptions& option, - uint64 start_walltime_ns, uint64 start_gpu_ns) - : CuptiTraceCollector(option), - num_callback_events_(0), - num_activity_events_(0), - start_walltime_ns_(start_walltime_ns), - start_gpu_ns_(start_gpu_ns), - num_gpus_(option.num_gpus), - per_device_collector_(option.num_gpus) {} - - void AddEvent(CuptiTracerEvent&& event) override { - if (event.device_id >= num_gpus_) return; - if (event.source == CuptiTracerEventSource::DriverCallback) { - if (num_callback_events_ > options_.max_callback_api_events) { - OnEventsDropped("total driver(callback) events reaches max", 1); - return; - } - num_callback_events_++; - } else { - if (num_activity_events_ > options_.max_activity_api_events) { - OnEventsDropped("total device(activity) events reaches max", 1); - return; - } - num_activity_events_++; - } - per_device_collector_[event.device_id].AddEvent(std::move(event)); - } - void OnEventsDropped(const std::string& reason, uint32 num_events) override { - absl::MutexLock lock(&mutex_); - dropped_events_[reason] += num_events; - } - void Flush() override {} - void Export(StepStats* step_stats) { - LOG(INFO) << " GpuTracer has collected " << num_callback_events_ - << " callback api events and " << num_activity_events_ - << " activity events. " << ReportDroppedEvents(); - for (int i = 0; i < num_gpus_; ++i) { - per_device_collector_[i].Flush(i, start_walltime_ns_, start_gpu_ns_, - step_stats); - } - } - void Export(XSpace* space) { - LOG(INFO) << " GpuTracer has collected " << num_callback_events_ - << " callback api events and " << num_activity_events_ - << " activity events. " << ReportDroppedEvents(); - uint64 end_gpu_ns = CuptiTracer::GetTimestamp(); - XPlaneBuilder host_plane( - FindOrAddMutablePlaneWithName(space, kCuptiDriverApiPlaneName)); - for (int device_ordinal = 0; device_ordinal < num_gpus_; ++device_ordinal) { - std::string name = GpuPlaneName(device_ordinal); - XPlaneBuilder device_plane(FindOrAddMutablePlaneWithName(space, name)); - device_plane.SetId(device_ordinal); - per_device_collector_[device_ordinal].Flush(start_gpu_ns_, end_gpu_ns, - &device_plane, &host_plane); - per_device_collector_[device_ordinal].GetDeviceCapabilities( - device_ordinal, &device_plane); - NormalizeTimeStamps(&device_plane, start_walltime_ns_); - } - NormalizeTimeStamps(&host_plane, start_walltime_ns_); - } - std::string ReportDroppedEvents() { - absl::MutexLock lock(&mutex_); - string result; - for (const auto& dropped : dropped_events_) { - absl::StrAppend(&result, " ", dropped.second, " events dropped because ", - dropped.first, ";"); - } - if (!result.empty()) result.back() = '.'; - return result; - } - std::string ReportNumEventsIfDropped() { - std::string events_dropped = ReportDroppedEvents(); - if (events_dropped.empty()) return ""; - return absl::StrCat("Detected GPU events dropped on ", port::Hostname(), - ": Profiler has collected ", - num_callback_events_.load(), " driver events and ", - num_activity_events_.load(), " device events.", - events_dropped); - } - - private: - std::atomic num_callback_events_; - std::atomic num_activity_events_; - absl::Mutex mutex_; - absl::flat_hash_map dropped_events_ - ABSL_GUARDED_BY(mutex_); - uint64 start_walltime_ns_; - uint64 start_gpu_ns_; - int num_gpus_; - - // Set the all XLines of specified XPlane to starting walltime. - // Events time in both host and device planes are CUTPI timestamps. - // We set initial CUPTI timestamp as start time for all lines to reflect - // this fact. Eventually we change line start time to corresponding - // start_walltime_ns to normalize with CPU wall time. - static void NormalizeTimeStamps(XPlaneBuilder* plane, - uint64 start_walltime_ns) { - plane->ForEachLine( - [&](XLineBuilder line) { line.SetTimestampNs(start_walltime_ns); }); - } - - struct CorrelationInfo { - CorrelationInfo(uint32 t, uint32 e) : thread_id(t), enqueue_time_ns(e) {} - uint32 thread_id; - uint64 enqueue_time_ns; - }; - struct PerDeviceCollector { - void AddEvent(CuptiTracerEvent&& event) { - mutex_lock l(m); - if (event.source == CuptiTracerEventSource::DriverCallback) { - // Cupti api callback events were used to populate launch times etc. - if (event.correlation_id != CuptiTracerEvent::kInvalidCorrelationId) { - correlation_info.insert( - {event.correlation_id, - CorrelationInfo(event.thread_id, event.start_time_ns)}); - } - events.emplace_back(std::move(event)); - } else { - // Cupti activity events measure device times etc. - events.emplace_back(std::move(event)); - } - } - - void Flush(int32 device_ordinal, uint64 start_walltime_ns, - uint64 start_gpu_ns, StepStats* step_stats) { - mutex_lock l(m); - absl::flat_hash_map, - DeviceStepStats*> - stream_dev_stats_map; - DeviceStepStats* unknown_stream_dev_stats = nullptr; - DeviceStepStats* all_streams_dev_stats = nullptr; - DeviceStepStats* memcpy_dev_stats = nullptr; - DeviceStepStats* sync_dev_stats = nullptr; - for (const CuptiTracerEvent& event : events) { - NodeExecStats* ns = new NodeExecStats; - ns->set_all_start_micros( - (start_walltime_ns + (event.start_time_ns - start_gpu_ns)) / 1000); - ns->set_op_start_rel_micros(0); - auto elapsed_ns = event.end_time_ns - event.start_time_ns; - ns->set_op_end_rel_micros(elapsed_ns / 1000); - ns->set_all_end_rel_micros(elapsed_ns / 1000); - - if (event.source == CuptiTracerEventSource::DriverCallback) { - // Legacy code ignore all other launch events except - // cuStreamSynchronize. - if (event.name == "cuStreamSynchronize") { - ns->set_node_name(event.name); - ns->set_timeline_label(absl::StrCat("ThreadId ", event.thread_id)); - ns->set_thread_id(event.thread_id); - if (sync_dev_stats == nullptr) { - sync_dev_stats = step_stats->add_dev_stats(); - sync_dev_stats->set_device( - absl::StrCat("/device:GPU:", device_ordinal, "/sync")); - } - sync_dev_stats->add_node_stats()->Swap(ns); - } - } else { // CuptiTracerEventSource::Activity - // Get launch information if available. - if (event.correlation_id != CuptiTracerEvent::kInvalidCorrelationId) { - auto it = correlation_info.find(event.correlation_id); - if (it != correlation_info.end()) { - ns->set_scheduled_micros(it->second.enqueue_time_ns / 1000); - ns->set_thread_id(it->second.thread_id); - } - } - - auto annotation_stack = ParseAnnotationStack(event.annotation); - std::string kernel_name = port::MaybeAbiDemangle(event.name.c_str()); - std::string activity_name = - !annotation_stack.empty() - ? std::string(annotation_stack.back().name) - : kernel_name; - ns->set_node_name(activity_name); - switch (event.type) { - case CuptiTracerEventType::Kernel: { - ns->set_timeline_label(absl::StrCat( - kernel_name, " regs:", event.kernel_info.registers_per_thread, - " shm:", event.kernel_info.static_shared_memory_usage, - " grid: ", event.kernel_info.grid_x, ",", - event.kernel_info.grid_y, ",", event.kernel_info.grid_z, - " block:", event.kernel_info.block_x, ",", - event.kernel_info.block_y, ",", event.kernel_info.block_z, - "@@", event.annotation)); - DeviceStepStats*& stream_dev_stats = - stream_dev_stats_map[std::make_pair(event.stream_id, - event.type)]; - if (stream_dev_stats == nullptr) { - stream_dev_stats = step_stats->add_dev_stats(); - stream_dev_stats->set_device( - absl::StrCat("/device:GPU:", device_ordinal, - "/stream:", event.stream_id)); - } - *stream_dev_stats->add_node_stats() = *ns; - if (all_streams_dev_stats == nullptr) { - all_streams_dev_stats = step_stats->add_dev_stats(); - all_streams_dev_stats->set_device(absl::StrCat( - "/device:GPU:", device_ordinal, "/stream:all")); - } - all_streams_dev_stats->add_node_stats()->Swap(ns); - break; - } - case CuptiTracerEventType::MemcpyH2D: - case CuptiTracerEventType::MemcpyD2H: - case CuptiTracerEventType::MemcpyD2D: - case CuptiTracerEventType::MemcpyP2P: { - std::string details = absl::StrCat( - activity_name, " bytes:", event.memcpy_info.num_bytes); - if (event.memcpy_info.async) { - absl::StrAppend(&details, " aync"); - } - if (event.memcpy_info.destination != event.device_id) { - absl::StrAppend(&details, - " to device:", event.memcpy_info.destination); - } - ns->set_timeline_label(std::move(details)); - DeviceStepStats*& stream_dev_stats = - stream_dev_stats_map[std::make_pair(event.stream_id, - event.type)]; - if (stream_dev_stats == nullptr) { - stream_dev_stats = step_stats->add_dev_stats(); - stream_dev_stats->set_device(absl::StrCat( - "/device:GPU:", device_ordinal, "/stream:", event.stream_id, - "<", GetTraceEventTypeName(event.type), ">")); - } - *stream_dev_stats->add_node_stats() = *ns; - if (memcpy_dev_stats == nullptr) { - memcpy_dev_stats = step_stats->add_dev_stats(); - memcpy_dev_stats->set_device( - absl::StrCat("/device:GPU:", device_ordinal, "/memcpy")); - } - memcpy_dev_stats->add_node_stats()->Swap(ns); - break; - } - default: - ns->set_timeline_label(activity_name); - if (unknown_stream_dev_stats == nullptr) { - unknown_stream_dev_stats = step_stats->add_dev_stats(); - unknown_stream_dev_stats->set_device( - absl::StrCat("/device:GPU:", device_ordinal, "/stream:")); - } - unknown_stream_dev_stats->add_node_stats()->Swap(ns); - break; - } - } - } - events.clear(); - } - - void Flush(uint64 start_gpu_ns, uint64 end_gpu_ns, - XPlaneBuilder* device_plane, XPlaneBuilder* host_plane) { - mutex_lock l(m); - // Tracking event types per line. - absl::flat_hash_map> - events_types_per_line; - for (auto& event : events) { - bool is_host_event = IsHostEvent(event); - int64 line_id = is_host_event ? static_cast(event.thread_id) - : event.stream_id; - if (line_id == CuptiTracerEvent::kInvalidThreadId || - line_id == CuptiTracerEvent::kInvalidStreamId) - continue; - auto* plane = is_host_event ? host_plane : device_plane; - XLineBuilder line = plane->GetOrCreateLine(line_id); - line.SetTimestampNs(start_gpu_ns); - CreateXEvent(event, plane, start_gpu_ns, end_gpu_ns, &line); - events_types_per_line[line_id].emplace(event.type); - } - device_plane->ForEachLine([&](XLineBuilder line) { - line.SetName( - GetDeviceXLineName(line.Id(), events_types_per_line[line.Id()])); - }); - host_plane->ForEachLine([&](XLineBuilder line) { - line.SetName(absl::StrCat("Host Threads/", line.Id())); - }); - events.clear(); - } - - void GetDeviceCapabilities(int32 device_ordinal, - XPlaneBuilder* device_plane) { - CUdevice device; - if (cuDeviceGet(&device, device_ordinal) != CUDA_SUCCESS) return; - - auto clock_rate_in_khz = - GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_CLOCK_RATE); - if (clock_rate_in_khz) { - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapClockRateKHz)), - *clock_rate_in_khz); - } - - auto core_count = - GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT); - if (core_count) { - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapCoreCount)), - *core_count); - } - - auto mem_clock_khz = - GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE); - auto mem_bus_width_bits = GetDeviceAttribute( - device, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH); - if (mem_clock_khz && mem_bus_width_bits) { - // Times 2 because HBM is DDR memory; it gets two data bits per each - // data lane. - auto memory_bandwidth = - uint64{2} * (*mem_clock_khz) * 1000 * (*mem_bus_width_bits) / 8; - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapMemoryBandwidth)), - memory_bandwidth); - } - - size_t total_memory = 0; - if (cuDeviceTotalMem(&total_memory, device) == CUDA_SUCCESS) { - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapMemorySize)), - static_cast(total_memory)); - } - - auto compute_capability_major = GetDeviceAttribute( - device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR); - if (compute_capability_major) { - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapComputeCapMajor)), - *compute_capability_major); - } - auto compute_capability_minor = GetDeviceAttribute( - device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR); - if (compute_capability_minor) { - device_plane->AddStatValue( - *device_plane->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kDevCapComputeCapMinor)), - *compute_capability_minor); - } - } - - mutex m; - std::vector events TF_GUARDED_BY(m); - absl::flat_hash_map correlation_info - TF_GUARDED_BY(m); - }; - absl::FixedArray per_device_collector_; - - TF_DISALLOW_COPY_AND_ASSIGN(CuptiTraceCollectorImpl); -}; - // GpuTracer for GPU. class GpuTracer : public profiler::ProfilerInterface { public: @@ -557,7 +70,7 @@ class GpuTracer : public profiler::ProfilerInterface { CuptiTracer* cupti_tracer_; CuptiTracerOptions options_; - std::unique_ptr cupti_collector_; + std::unique_ptr cupti_collector_; }; Status GpuTracer::DoStart() { @@ -621,8 +134,8 @@ Status GpuTracer::DoStart() { collector_options.num_gpus = cupti_tracer_->NumGpus(); uint64 start_gputime_ns = CuptiTracer::GetTimestamp(); uint64 start_walltime_ns = tensorflow::EnvTime::NowNanos(); - cupti_collector_ = absl::make_unique( - collector_options, start_walltime_ns, start_gputime_ns); + cupti_collector_ = CreateCuptiCollector(collector_options, start_walltime_ns, + start_gputime_ns); AnnotationStack::Enable(true); cupti_tracer_->Enable(options_, cupti_collector_.get()); @@ -683,6 +196,7 @@ Status GpuTracer::CollectData(RunMetadata* run_metadata) { } Status GpuTracer::CollectData(XSpace* space) { + VLOG(2) << "Collecting data to XSpace from GpuTracer."; switch (profiling_state_) { case State::kNotStarted: VLOG(1) << "No trace data collected, session wasn't started"; @@ -705,7 +219,8 @@ Status GpuTracer::CollectData(XSpace* space) { space->add_warnings(std::move(events_dropped)); } if (cupti_collector_) { - cupti_collector_->Export(space); + uint64 end_gpu_ns = CuptiTracer::GetTimestamp(); + cupti_collector_->Export(space, end_gpu_ns); } return Status::OK(); } @@ -716,7 +231,7 @@ Status GpuTracer::CollectData(XSpace* space) { // Not in anonymous namespace for testing purposes. std::unique_ptr CreateGpuTracer( const ProfileOptions& options) { - VLOG(2) << "Collecting data to XSpace from GpuTracer."; + if (options.device_tracer_level() == 0) return nullptr; if (options.device_type() != ProfileOptions::GPU && options.device_type() != ProfileOptions::UNSPECIFIED) return nullptr; diff --git a/tensorflow/core/profiler/lib/profiler_session.h b/tensorflow/core/profiler/lib/profiler_session.h index c179f1710c9..976ebcfc884 100644 --- a/tensorflow/core/profiler/lib/profiler_session.h +++ b/tensorflow/core/profiler/lib/profiler_session.h @@ -38,7 +38,7 @@ namespace tensorflow { // Thread-safety: ProfilerSession is thread-safe. class ProfilerSession { public: - // Creates and ProfilerSession and starts profiling. + // Creates a ProfilerSession and starts profiling. static std::unique_ptr Create(const ProfileOptions& options); static ProfileOptions DefaultOptions() { diff --git a/tensorflow/core/profiler/lib/traceme.h b/tensorflow/core/profiler/lib/traceme.h index ca3ea735765..976fcfc82dd 100644 --- a/tensorflow/core/profiler/lib/traceme.h +++ b/tensorflow/core/profiler/lib/traceme.h @@ -192,6 +192,23 @@ class TraceMe { // Static API, for use when scoped objects are inconvenient. + // Record the start time of an activity. + // Returns the activity ID, which is used to stop the activity. + // Calls `name_generator` to get the name for activity. + template + static uint64 ActivityStart(NameGeneratorT name_generator, int level = 1) { +#if !defined(IS_MOBILE_PLATFORM) + if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) { + uint64 activity_id = TraceMeRecorder::NewActivityId(); + TraceMeRecorder::Record({activity_id, name_generator(), + /*start_time=*/EnvTime::NowNanos(), + /*end_time=*/0}); + return activity_id; + } +#endif + return kUntracedActivity; + } + // Record the start time of an activity. // Returns the activity ID, which is used to stop the activity. static uint64 ActivityStart(absl::string_view name, int level = 1) { @@ -207,6 +224,16 @@ class TraceMe { return kUntracedActivity; } + // Same as ActivityStart above, an overload for "const std::string&" + static uint64 ActivityStart(const std::string& name, int level = 1) { + return ActivityStart(absl::string_view(name), level); + } + + // Same as ActivityStart above, an overload for "const char*" + static uint64 ActivityStart(const char* name, int level = 1) { + return ActivityStart(absl::string_view(name), level); + } + // Record the end time of an activity started by ActivityStart(). static void ActivityEnd(uint64 activity_id) { #if !defined(IS_MOBILE_PLATFORM) diff --git a/tensorflow/core/profiler/protobuf/tf_data_stats.proto b/tensorflow/core/profiler/protobuf/tf_data_stats.proto index f4edc4144d4..25c19614fc1 100644 --- a/tensorflow/core/profiler/protobuf/tf_data_stats.proto +++ b/tensorflow/core/profiler/protobuf/tf_data_stats.proto @@ -76,6 +76,8 @@ message InputPipelineStats { int64 min_latency_ps = 4; // Maximum latency of the input pipeline. int64 max_latency_ps = 5; + // The number of times this input pipeline was slower than 50 us. + int64 num_slow_calls = 6; // Stats per call sorted by the root iterator's duration. repeated InputPipelineStat stats = 2; } @@ -87,3 +89,30 @@ message TfDataStats { // Stats per input pipeline. map input_pipelines = 1; } + +message TfDataBottleneckAnalysis { + // Host name. + string host = 1; + // Input pipeline name. + string input_pipeline = 2; + // Maximum latency of the input pipeline. + int64 max_latency_ps = 3; + // Name of the bottleneck iterator. + string iterator_name = 4; + // Long name of the bottleneck iterator. + string iterator_long_name = 5; + // Suggestion to resolve the bottleneck. + string suggestion = 6; +} + +// TfDataStats of all hosts. +message CombinedTfDataStats { + // Whether it is input bound. + bool is_input_bound = 3; + // Summary of the analysis. + string summary = 4; + // Bottleneck analysis result. + TfDataBottleneckAnalysis bottleneck_analysis = 1; + // TfDataStats per host. + map tf_data_stats = 2; +} diff --git a/tensorflow/core/profiler/protobuf/xplane.proto b/tensorflow/core/profiler/protobuf/xplane.proto index dd34c2f40b1..f57d7609891 100644 --- a/tensorflow/core/profiler/protobuf/xplane.proto +++ b/tensorflow/core/profiler/protobuf/xplane.proto @@ -5,13 +5,15 @@ package tensorflow.profiler; option cc_enable_arenas = true; // A container of parallel XPlanes, generated by one or more profiling sources. -// Next ID: 4 +// Next ID: 5 message XSpace { repeated XPlane planes = 1; // Errors (if any) in the generation of planes. repeated string errors = 2; // Warnings (if any) in the generation of planes; repeated string warnings = 3; + // List of hostnames that XPlanes are generated from. + repeated string hostnames = 4; } // An XPlane is a container of parallel timelines (XLines), generated by a diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD index efe9e6eec9a..ca1ff506f2a 100644 --- a/tensorflow/core/profiler/rpc/client/BUILD +++ b/tensorflow/core/profiler/rpc/client/BUILD @@ -116,7 +116,6 @@ cc_library( tf_cc_test( name = "profiler_client_test", srcs = ["profiler_client_test.cc"], - tags = ["external"], # So that test suite reruns unconditionally. deps = [ ":profiler_client", ":profiler_client_impl", # for oss @@ -149,7 +148,6 @@ cc_library( tf_cc_test( name = "remote_profiler_session_manager_test", srcs = ["remote_profiler_session_manager_test.cc"], - tags = ["external"], # So that test suite reruns unconditionally. deps = [ ":profiler_client_impl", # for oss ":profiler_client_test_util", diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc index 54eedb65fa0..e8690f1f1f8 100644 --- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc +++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc @@ -50,6 +50,7 @@ Status CollectDataToRepository(const ProfileRequest& request, // Read the profile data into xspace. XSpace xspace; TF_RETURN_IF_ERROR(profiler->CollectData(&xspace)); + xspace.add_hostnames(request.host_name()); VLOG(3) << "Collected XSpace to repository."; response->set_empty_trace(IsEmpty(xspace)); diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index f7c6d5496d5..7476c5aa0c5 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -240,8 +240,7 @@ cc_library( ":trace_utils", ":xplane_builder", ":xplane_visitor", - "//tensorflow/core:platform_base", - "//tensorflow/core/platform:types", + "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -255,7 +254,7 @@ tf_cc_test( ":xplane_builder", ":xplane_utils", ":xplane_visitor", - "//tensorflow/core:platform_base", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", @@ -276,7 +275,7 @@ cc_library( ":xplane_builder", ":xplane_schema", ":xplane_utils", - "//tensorflow/core/platform:types", + "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -527,3 +526,21 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "device_caps_utils", + srcs = ["device_caps_utils.cc"], + hdrs = ["device_caps_utils.h"], + copts = tf_profiler_copts(), + visibility = [":friends"], + deps = [ + ":xplane_builder", + ":xplane_schema", + ":xplane_visitor", + "//tensorflow/core:platform_base", + "//tensorflow/core/platform:types", + "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/core/profiler/utils/device_caps_utils.cc b/tensorflow/core/profiler/utils/device_caps_utils.cc new file mode 100644 index 00000000000..44ab1692422 --- /dev/null +++ b/tensorflow/core/profiler/utils/device_caps_utils.cc @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/utils/device_caps_utils.h" + +#include "tensorflow/core/profiler/utils/xplane_builder.h" +#include "tensorflow/core/profiler/utils/xplane_schema.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { + +void SetDeviceCaps(const DeviceCapabilities& caps, XPlane* plane) { + XPlaneBuilder xplane(plane); + int clock_rate_in_khz = + static_cast(caps.clock_rate_in_ghz() * 1000000.0); + xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapClockRateKHz)), + clock_rate_in_khz); + xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapCoreCount)), + caps.num_cores()); + xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapMemoryBandwidth)), + caps.memory_bandwidth()); + xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapMemorySize)), + caps.memory_size_in_bytes()); + if (caps.has_compute_capability()) { + xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapComputeCapMajor)), + caps.compute_capability().major()); + xplane.AddStatValue(*xplane.GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDevCapComputeCapMinor)), + caps.compute_capability().minor()); + } +} + +DeviceCapabilities GetDeviceCaps(const XPlane& plane) { + DeviceCapabilities caps; + XPlaneVisitor xplane(&plane); + xplane.ForEachStat([&](const tensorflow::profiler::XStatVisitor& stat) { + if (!stat.Type().has_value()) return; + switch (stat.Type().value()) { + case StatType::kDevCapClockRateKHz: + caps.set_clock_rate_in_ghz(stat.IntOrUintValue() * 1000000.0); + break; + case StatType::kDevCapCoreCount: + caps.set_num_cores(stat.IntOrUintValue()); + break; + case StatType::kDevCapMemoryBandwidth: + caps.set_memory_bandwidth(stat.IntOrUintValue()); + break; + case StatType::kDevCapMemorySize: + caps.set_memory_size_in_bytes(stat.IntOrUintValue()); + break; + case StatType::kDevCapComputeCapMajor: + caps.mutable_compute_capability()->set_major(stat.IntOrUintValue()); + break; + case StatType::kDevCapComputeCapMinor: + caps.mutable_compute_capability()->set_minor(stat.IntOrUintValue()); + break; + } + }); + + return caps; +} + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/device_caps_utils.h b/tensorflow/core/profiler/utils/device_caps_utils.h new file mode 100644 index 00000000000..5ab116ff3a0 --- /dev/null +++ b/tensorflow/core/profiler/utils/device_caps_utils.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAP_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAP_UTILS_H_ + +#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +void SetDeviceCaps(const DeviceCapabilities& caps, XPlane* plane); +DeviceCapabilities GetDeviceCaps(const XPlane& plane); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAP_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/time_utils.h b/tensorflow/core/profiler/utils/time_utils.h index 0a2518b90ff..cef1bda0b76 100644 --- a/tensorflow/core/profiler/utils/time_utils.h +++ b/tensorflow/core/profiler/utils/time_utils.h @@ -22,6 +22,8 @@ namespace tensorflow { namespace profiler { // Converts among different time units. +// NOTE: We use uint64 for picoseconds and nanoseconds, which are used in +// storage, and double for other units that are used in the UI. inline double PicosToNanos(uint64 ps) { return ps / 1E3; } inline double PicosToMicros(uint64 ps) { return ps / 1E6; } inline double PicosToMillis(uint64 ps) { return ps / 1E9; } @@ -29,9 +31,9 @@ inline double PicosToSeconds(uint64 ps) { return ps / 1E12; } inline uint64 NanosToPicos(uint64 ns) { return ns * 1000; } inline double NanosToMicros(uint64 ns) { return ns / 1E3; } inline double MicrosToMillis(double us) { return us / 1E3; } -inline uint64 MillisToPicos(uint64 ms) { return ms * 1000000000; } -inline uint64 MillisToNanos(uint64 ms) { return ms * 1000000; } -inline double MillisToSeconds(uint64 ms) { return ms / 1E3; } +inline uint64 MillisToPicos(double ms) { return ms * 1E9; } +inline uint64 MillisToNanos(double ms) { return ms * 1E6; } +inline double MillisToSeconds(double ms) { return ms / 1E3; } inline uint64 SecondsToNanos(double s) { return s * 1E9; } } // namespace profiler diff --git a/tensorflow/core/profiler/utils/xplane_builder.h b/tensorflow/core/profiler/utils/xplane_builder.h index ded2b353c2c..2504f4b5c48 100644 --- a/tensorflow/core/profiler/utils/xplane_builder.h +++ b/tensorflow/core/profiler/utils/xplane_builder.h @@ -34,7 +34,7 @@ namespace profiler { class XPlaneBuilder; -template +template class XStatsBuilder { public: explicit XStatsBuilder(T* stats_owner, XPlaneBuilder* stats_metadata_owner) @@ -79,7 +79,19 @@ class XStatsBuilder { proto.SerializeToString(bytes); } - void AddStat(const XStatMetadata& key, const XStat& stat, const XPlane& src); + void AddStat(const XStatMetadata& key, const XStat& stat, const XPlane& src) { + if (stat.value_case() == XStat::kRefValue) { + const auto& stat_metadata_map = src.stat_metadata(); + const auto it = stat_metadata_map.find(stat.ref_value()); + if (TF_PREDICT_TRUE(it != stat_metadata_map.end())) { + AddStatRefValue(key, it->second.name()); + } + } else { + XStat* new_stat = stats_owner_->add_stats(); + *new_stat = stat; + new_stat->set_metadata_id(key.id()); + } + } XStat* FindOrAddMutableStat(int64 metadata_id) { for (auto& stat : *stats_owner_->mutable_stats()) { @@ -102,7 +114,7 @@ class XStatsBuilder { } else if (absl::SimpleAtod(value, &double_value)) { AddStatValue(metadata, double_value); } else { - AddStatValue(metadata, value); + AddStatRefValue(metadata, value); } } void ReserveStats(size_t num_stats) { @@ -116,6 +128,8 @@ class XStatsBuilder { return stat; } + void AddStatRefValue(const XStatMetadata& metadata, absl::string_view value); + T* stats_owner_; XPlaneBuilder* stats_metadata_owner_; }; @@ -278,24 +292,12 @@ class XPlaneBuilder : public XStatsBuilder { absl::flat_hash_map lines_by_id_; }; -template -void XStatsBuilder::AddStat(const XStatMetadata& key, const XStat& stat, - const XPlane& src) { - if (stat.value_case() == XStat::kRefValue) { - const auto& stat_metadata_map = src.stat_metadata(); - const auto it = stat_metadata_map.find(stat.ref_value()); - if (TF_PREDICT_FALSE(it == stat_metadata_map.end())) { - // the reference value in stat is not found in XStatMetadata from src. - return; - } - XStatMetadata* value = - stats_metadata_owner_->GetOrCreateStatMetadata(it->second.name()); - AddStatValue(key, *value); - } else { - XStat* new_stat = stats_owner_->add_stats(); - *new_stat = stat; - new_stat->set_metadata_id(key.id()); - } +template +void XStatsBuilder::AddStatRefValue(const XStatMetadata& metadata, + absl::string_view value) { + const XStatMetadata* ref_value = + stats_metadata_owner_->GetOrCreateStatMetadata(value); + AddStatValue(metadata, *ref_value); } } // namespace profiler diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc index 78fd7af584c..858dd7a99ba 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.cc +++ b/tensorflow/core/profiler/utils/xplane_schema.cc @@ -112,6 +112,10 @@ const HostEventTypeMap& GetHostEventTypeMap() { // Batching related. {"BatchingSessionRun", kBatchingSessionRun}, {"ProcessBatch", kProcessBatch}, + {"ConcatInputTensors", kConcatInputTensors}, + {"MergeInputTensors", kMergeInputTensors}, + {"ScheduleWithoutSplit", kScheduleWithoutSplit}, + {"ScheduleWithSplit", kScheduleWithSplit}, // JAX related. {"LocalExecutable::ExecuteOnLocalDevices", kExecuteOnLocalDevices}, // GPU related. @@ -202,6 +206,10 @@ const StatTypeMap& GetStatTypeMap() { {"memory_size", kDevCapMemorySize}, {"compute_cap_major", kDevCapComputeCapMajor}, {"compute_cap_minor", kDevCapComputeCapMinor}, + // Batching related. + {"batch_size_after_padding", kBatchSizeAfterPadding}, + {"padding_amount", kPaddingAmount}, + {"batching_input_task_size", kBatchingInputTaskSize}, }); DCHECK_EQ(stat_type_map->size(), kNumStatTypes); return *stat_type_map; diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index 8a9ac4fd278..dd8b4fe5140 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -103,6 +103,10 @@ enum HostEventType { // Batching related. kBatchingSessionRun, kProcessBatch, + kConcatInputTensors, + kMergeInputTensors, + kScheduleWithoutSplit, + kScheduleWithSplit, // JAX related. kExecuteOnLocalDevices, // GPU related. @@ -191,7 +195,11 @@ enum StatType { kDevCapMemorySize, kDevCapComputeCapMajor, kDevCapComputeCapMinor, - kLastStatType = kDevCapComputeCapMinor, + // Batching related. + kBatchSizeAfterPadding, + kPaddingAmount, + kBatchingInputTaskSize, + kLastStatType = kBatchingInputTaskSize, }; inline std::string GpuPlaneName(int32 device_ordinal) { diff --git a/tensorflow/core/profiler/utils/xplane_utils.cc b/tensorflow/core/profiler/utils/xplane_utils.cc index 1389af2f16a..4af4fb79491 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.cc +++ b/tensorflow/core/profiler/utils/xplane_utils.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/timespan.h" @@ -34,6 +35,37 @@ namespace tensorflow { namespace profiler { namespace { +// Returns the index of the first element in array for which pred is true. +// Returns -1 if no such element is found. +template +int FindIf(const protobuf::RepeatedPtrField& array, Pred&& pred) { + for (int i = 0; i < array.size(); ++i) { + if (pred(&array.Get(i))) return i; + } + return -1; +} + +// Removes the given element from array. +template +void Remove(protobuf::RepeatedPtrField* array, const T* elem) { + int i = FindIf(*array, [elem](const T* e) { return elem == e; }); + if (i == -1) return; + for (; i < array->size() - 1; ++i) { + array->SwapElements(i + 1, i); + } + array->RemoveLast(); +} + +template +void RemoveIf(protobuf::RepeatedPtrField* array, Pred&& pred) { + int i = FindIf(*array, pred); + if (i == -1) return; + for (int j = i + 1; j < array->size(); ++j) { + if (!pred(&array->Get(j))) array->SwapElements(j, i++); + } + array->DeleteSubrange(i, array->size() - i); +} + // Creates a Timespan from an XEvent. // WARNING: This should only be used when comparing events from the same XLine. Timespan XEventTimespan(const XEvent& event) { @@ -43,17 +75,15 @@ Timespan XEventTimespan(const XEvent& event) { } // namespace const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name) { - for (const XPlane& plane : space.planes()) { - if (plane.name() == name) return &plane; - } - return nullptr; + int i = FindIf(space.planes(), + [name](const XPlane* plane) { return plane->name() == name; }); + return (i != -1) ? &space.planes(i) : nullptr; } XPlane* FindMutablePlaneWithName(XSpace* space, absl::string_view name) { - for (XPlane& plane : *space->mutable_planes()) { - if (plane.name() == name) return &plane; - } - return nullptr; + int i = FindIf(space->planes(), + [name](const XPlane* plane) { return plane->name() == name; }); + return (i != -1) ? space->mutable_planes(i) : nullptr; } XPlane* FindOrAddMutablePlaneWithName(XSpace* space, absl::string_view name) { @@ -112,29 +142,19 @@ void AddOrUpdateStrStat(int64 metadata_id, absl::string_view value, stat->set_str_value(std::string(value)); } -void RemovePlaneWithName(XSpace* space, absl::string_view name) { - auto* planes = space->mutable_planes(); - planes->erase( - std::remove_if(planes->begin(), planes->end(), - [&](const XPlane& plane) { return plane.name() == name; }), - planes->end()); +void RemovePlane(XSpace* space, const XPlane* plane) { + DCHECK(plane != nullptr); + Remove(space->mutable_planes(), plane); } void RemoveEmptyPlanes(XSpace* space) { - auto* planes = space->mutable_planes(); - planes->erase(std::remove_if(planes->begin(), planes->end(), - [&](const XPlane& plane) { - return plane.lines_size() == 0; - }), - planes->end()); + RemoveIf(space->mutable_planes(), + [&](const XPlane* plane) { return plane->lines().empty(); }); } void RemoveEmptyLines(XPlane* plane) { - auto* lines = plane->mutable_lines(); - lines->erase(std::remove_if( - lines->begin(), lines->end(), - [&](const XLine& line) { return line.events_size() == 0; }), - lines->end()); + RemoveIf(plane->mutable_lines(), + [&](const XLine* line) { return line->events().empty(); }); } bool XEventsComparator::operator()(const XEvent* a, const XEvent* b) const { diff --git a/tensorflow/core/profiler/utils/xplane_utils.h b/tensorflow/core/profiler/utils/xplane_utils.h index eab2c8a8858..77abb2c53d7 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.h +++ b/tensorflow/core/profiler/utils/xplane_utils.h @@ -50,7 +50,8 @@ void AddOrUpdateIntStat(int64 metadata_id, int64 value, void AddOrUpdateStrStat(int64 metadata_id, absl::string_view value, tensorflow::profiler::XEvent* event); -void RemovePlaneWithName(XSpace* space, absl::string_view name); +void RemovePlane(XSpace* space, const XPlane* plane); + void RemoveEmptyPlanes(XSpace* space); void RemoveEmptyLines(XPlane* plane); diff --git a/tensorflow/core/profiler/utils/xplane_utils_test.cc b/tensorflow/core/profiler/utils/xplane_utils_test.cc index 04e06fcb05b..21c87b5c872 100644 --- a/tensorflow/core/profiler/utils/xplane_utils_test.cc +++ b/tensorflow/core/profiler/utils/xplane_utils_test.cc @@ -51,23 +51,28 @@ TEST(XPlaneUtilsTest, IsNestedTest) { EXPECT_FALSE(IsNested(event, not_parent)); } -TEST(XPlaneUtilsTest, RemovePlaneWithName) { +TEST(XPlaneUtilsTest, AddAndRemovePlanes) { XSpace space; - RemovePlaneWithName(&space, "non-exist"); - EXPECT_EQ(space.planes_size(), 0); - space.add_planes()->set_name("p1"); - space.add_planes()->set_name("p2"); - space.add_planes()->set_name("p3"); - RemovePlaneWithName(&space, "non-exist"); - EXPECT_EQ(space.planes_size(), 3); - RemovePlaneWithName(&space, "p2"); + auto* p1 = FindOrAddMutablePlaneWithName(&space, "p1"); + EXPECT_EQ(p1, FindPlaneWithName(space, "p1")); + auto* p2 = FindOrAddMutablePlaneWithName(&space, "p2"); + EXPECT_EQ(p2, FindPlaneWithName(space, "p2")); + auto* p3 = FindOrAddMutablePlaneWithName(&space, "p3"); + EXPECT_EQ(p3, FindPlaneWithName(space, "p3")); + + // Removing a plane does not invalidate pointers to other planes. + + RemovePlane(&space, p2); EXPECT_EQ(space.planes_size(), 2); - RemovePlaneWithName(&space, "p1"); + EXPECT_EQ(p1, FindPlaneWithName(space, "p1")); + EXPECT_EQ(p3, FindPlaneWithName(space, "p3")); + + RemovePlane(&space, p1); EXPECT_EQ(space.planes_size(), 1); - RemovePlaneWithName(&space, "p1"); - EXPECT_EQ(space.planes_size(), 1); - RemovePlaneWithName(&space, "p3"); + EXPECT_EQ(p3, FindPlaneWithName(space, "p3")); + + RemovePlane(&space, p3); EXPECT_EQ(space.planes_size(), 0); } diff --git a/tensorflow/core/protobuf/data/experimental/service_config.proto b/tensorflow/core/protobuf/data/experimental/service_config.proto index 7a0aa16e2c4..3dcd2cd48d0 100644 --- a/tensorflow/core/protobuf/data/experimental/service_config.proto +++ b/tensorflow/core/protobuf/data/experimental/service_config.proto @@ -37,4 +37,7 @@ message WorkerConfig { string worker_address = 4; // How often the worker should heartbeat to the master. int64 heartbeat_interval_ms = 5; + // How long to retry requests to the dispatcher before giving up and reporting + // an error. + int64 dispatcher_timeout_ms = 6; } diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto index a5b4cfbe823..8df58683ead 100644 --- a/tensorflow/core/protobuf/saved_object_graph.proto +++ b/tensorflow/core/protobuf/saved_object_graph.proto @@ -76,6 +76,9 @@ message SavedUserObject { string identifier = 1; // Version information from the producer of this SavedUserObject. VersionDef version = 2; + // Deprecated! At the time of deprecation, Keras was the only user of this + // field, and its saving and loading code will be updated shortly. + // Please save your application-specific metadata to separate file // Initialization-related metadata. string metadata = 3; } diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 5556ec07c4f..a0ef352232d 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 553 // Updated: 2020/10/13 +#define TF_GRAPH_DEF_VERSION 562 // Updated: 2020/10/22 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc index 40f9353beb4..b7c71dc5cba 100644 --- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc @@ -326,6 +326,28 @@ Status RemoveIdentityNodesForArgRetval(Graph* g) { return Status::OK(); } +// Updates the TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR when +// 'additional_per_replicate_inputs' are added to the inputs of `xla_node`. +Status UpdateMirroredVariableIndices(int additional_per_replica_inputs, + Node* xla_node) { + std::vector mirrored_variable_indices; + if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) != + nullptr) { + TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), + TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR, + &mirrored_variable_indices)); + } + + if (!mirrored_variable_indices.empty()) { + for (int i = 0; i < mirrored_variable_indices.size(); ++i) + mirrored_variable_indices[i] += additional_per_replica_inputs; + xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR); + xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR, + mirrored_variable_indices); + } + return Status::OK(); +} + // Move outside compilation nodes at the beginning of XLA computation to host. // For XLA computation graph, we will add new _Arg nodes to replace those // outside compilation nodes. @@ -545,6 +567,9 @@ Status MoveHeadOutsideCompilationToHost( xla_node->ClearAttr("Tinputs"); xla_node->AddAttr("Tinputs", new_input_types); + TF_RETURN_IF_ERROR(UpdateMirroredVariableIndices( + /*additional_per_replica_inputs=*/oc_output_edges.size(), xla_node)); + int new_variable_start_index = num_new_per_replica_input_types / num_replicas + num_distributed_vars + broadcast_input_types.size(); diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc index 29ec8701a37..6a0bab4a0a7 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc @@ -111,23 +111,6 @@ xla::StatusOr> SerializeCacheEntryToBufferSlices( } header.set_is_empty(false); - HostComputeMetadataSerializedProto host_compute_metadata; - auto cleanup_host_compute_metadata = - xla::MakeCleanup([&host_compute_metadata]() { - if (host_compute_metadata.size > 0) { - stream_executor::tpu::SerializedProto_Free(host_compute_metadata); - } - }); - Status get_host_compute_metadata_status = - tpu_program_group->SerializeHostComputeMetadata(cache_entry.core_index(), - &host_compute_metadata); - if (!get_host_compute_metadata_status.ok()) { - return errors::Internal("Failed to serialize host compute metadata."); - } - if (!header.mutable_host_compute_metadata()->ParseFromArray( - host_compute_metadata.bytes, host_compute_metadata.size)) { - return errors::Internal("Failed to deserialize host compute metadata."); - } bool may_modify_variables = tpu_program_group->may_modify_variables(cache_entry.core_index()); diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc index 73669c21f1e..eeb396349bb 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc @@ -497,26 +497,37 @@ Status TpuCompileOpKernelCommon::OptimizeGraph( opts.set_do_function_inlining(true); opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding); GraphOptimizer optimizer(opts); - // Performs a first function inlining pass before shape inference, since - // otherwise shape inference can't see inside functions and a comprehensive - // shape_map, including function ops, is needed to constant-propagate Shape - // Ops below. - GraphOptimizer::Options optimizer_opts; - optimizer_opts.inline_multi_device_functions = true; - optimizer_opts.inline_impl_selection_group_functions = true; - optimizer_opts.inline_with_single_device_body_placer = true; - optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts); + { + // Performs a first function inlining pass before shape inference, since + // otherwise shape inference can't see inside functions and a comprehensive + // shape_map, including function ops, is needed to constant-propagate Shape + // Ops below. + GraphOptimizer::Options optimizer_opts; + optimizer_opts.inline_multi_device_functions = true; + optimizer_opts.inline_impl_selection_group_functions = true; + optimizer_opts.inline_with_single_device_body_placer = true; + // Infer shapes for each node in the computation. Shape inference can help + // skip constant folding of large shapes. + GraphShapeInfo shape_info; + TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation( + metadata, arg_shapes, graph->get(), flr, &shape_info)); + // Converts the GraphShapeInfo into the form needed by the constant-folding + // pass of the optimizer. + std::unordered_map> shape_map; + ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map); + optimizer_opts.shape_map = &shape_map; + optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts); + } - // Infer shapes for each node in the computation. - GraphShapeInfo shape_info; - TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation( - metadata, arg_shapes, graph->get(), flr, &shape_info)); - - // Converts the GraphShapeInfo into the form needed by the constant-folding - // pass of the optimizer. - std::unordered_map> shape_map; - ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map); - optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map); + { + // Infer shapes for each node in the computation. + GraphShapeInfo shape_info; + TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation( + metadata, arg_shapes, graph->get(), flr, &shape_info)); + std::unordered_map> shape_map; + ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map); + optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map); + } TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld)); diff --git a/tensorflow/core/tpu/kernels/tpu_program_c_api.h b/tensorflow/core/tpu/kernels/tpu_program_c_api.h index 1b35a8a036b..d6e46a7c419 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_c_api.h +++ b/tensorflow/core/tpu/kernels/tpu_program_c_api.h @@ -105,11 +105,6 @@ TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata( const XLA_TpuProgram* tpu_program, CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status); -// Gets host transfer metadata proto from a `tpu_program`. -TFTPU_CAPI_EXPORT void TpuProgram_SerializeHostComputeMetadata( - const XLA_TpuProgram* tpu_program, - HostComputeMetadataSerializedProto* host_compute_metadata, - SE_Status* status); // Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`. TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto( @@ -132,7 +127,6 @@ struct TfTpu_TpuProgramApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram); TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable); TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata); - TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeHostComputeMetadata); TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto); }; diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.cc b/tensorflow/core/tpu/kernels/tpu_program_group.cc index cd2d5bda98c..abc53cfc0eb 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.cc +++ b/tensorflow/core/tpu/kernels/tpu_program_group.cc @@ -333,16 +333,5 @@ Status TpuProgramGroup::SerializeCompilerMetadata( tpu_programs_[index], compiler_metadata, status.c_status); return status.status(); } - -Status TpuProgramGroup::SerializeHostComputeMetadata( - int index, - HostComputeMetadataSerializedProto* host_compute_metadata) const { - CHECK_GE(index, 0); - CHECK_LT(index, tpu_programs_.size()); - StatusHelper status; - TpuProgramApiFn()->TpuProgram_SerializeHostComputeMetadataFn( - tpu_programs_[index], host_compute_metadata, status.c_status); - return status.status(); -} } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc index de245340b8a..f824c9202e5 100644 --- a/tensorflow/core/tpu/tpu_library_init_fns.inc +++ b/tensorflow/core/tpu/tpu_library_init_fns.inc @@ -77,7 +77,6 @@ tensorflow::Status SetTpuProgramStructFn(void* library_handle) { TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram); TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeTpuExecutable); TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeCompilerMetadata); - TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeHostComputeMetadata); TFTPU_SET_FN(tpu_program_fn, TpuProgram_DeserializeFromGetTpuProgramResponseProto); diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index b9342638592..e8179d33d7e 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -668,7 +668,7 @@ cc_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/framework:bounds_check", "//third_party/eigen3", ], ) @@ -703,6 +703,9 @@ tf_cuda_only_cc_test( srcs = [ "gpu_kernel_helper_test.cu.cc", ], + tags = [ + "no_cuda_asan", # TODO(b/171342366): re-enable. + ], deps = [ "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -737,7 +740,6 @@ tf_cc_tests( "tensor_slice_writer_test.cc", "work_sharder_test.cc", ], - create_named_test_suite = True, linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 00a9cbaa3d8..ce83f2d2fb3 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -132,51 +132,61 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, } // namespace -Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text) +Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text, + bool* dst_updated) : name_(name), type_(TYPE_INT32), - int32_hook_([dst](int32 value) { + int32_hook_([dst, dst_updated](int32 value) { *dst = value; + if (dst_updated) *dst_updated = true; return true; }), int32_default_for_display_(*dst), usage_text_(usage_text) {} -Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text) +Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text, + bool* dst_updated) : name_(name), type_(TYPE_INT64), - int64_hook_([dst](int64 value) { + int64_hook_([dst, dst_updated](int64 value) { *dst = value; + if (dst_updated) *dst_updated = true; return true; }), int64_default_for_display_(*dst), usage_text_(usage_text) {} -Flag::Flag(const char* name, float* dst, const string& usage_text) +Flag::Flag(const char* name, float* dst, const string& usage_text, + bool* dst_updated) : name_(name), type_(TYPE_FLOAT), - float_hook_([dst](float value) { + float_hook_([dst, dst_updated](float value) { *dst = value; + if (dst_updated) *dst_updated = true; return true; }), float_default_for_display_(*dst), usage_text_(usage_text) {} -Flag::Flag(const char* name, bool* dst, const string& usage_text) +Flag::Flag(const char* name, bool* dst, const string& usage_text, + bool* dst_updated) : name_(name), type_(TYPE_BOOL), - bool_hook_([dst](bool value) { + bool_hook_([dst, dst_updated](bool value) { *dst = value; + if (dst_updated) *dst_updated = true; return true; }), bool_default_for_display_(*dst), usage_text_(usage_text) {} -Flag::Flag(const char* name, string* dst, const string& usage_text) +Flag::Flag(const char* name, string* dst, const string& usage_text, + bool* dst_updated) : name_(name), type_(TYPE_STRING), - string_hook_([dst](string value) { + string_hook_([dst, dst_updated](string value) { *dst = std::move(value); + if (dst_updated) *dst_updated = true; return true; }), string_default_for_display_(*dst), diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h index 928ae8a4e94..3d583a605b5 100644 --- a/tensorflow/core/util/command_line_flags.h +++ b/tensorflow/core/util/command_line_flags.h @@ -62,11 +62,16 @@ namespace tensorflow { // text, and a pointer to the corresponding variable. class Flag { public: - Flag(const char* name, int32* dst, const string& usage_text); - Flag(const char* name, int64* dst, const string& usage_text); - Flag(const char* name, bool* dst, const string& usage_text); - Flag(const char* name, string* dst, const string& usage_text); - Flag(const char* name, float* dst, const string& usage_text); + Flag(const char* name, int32* dst, const string& usage_text, + bool* dst_updated = nullptr); + Flag(const char* name, int64* dst, const string& usage_text, + bool* dst_updated = nullptr); + Flag(const char* name, bool* dst, const string& usage_text, + bool* dst_updated = nullptr); + Flag(const char* name, string* dst, const string& usage_text, + bool* dst_updated = nullptr); + Flag(const char* name, float* dst, const string& usage_text, + bool* dst_updated = nullptr); // These constructors invoke a hook on a match instead of writing to a // specific memory location. The hook may return false to signal a malformed @@ -85,6 +90,8 @@ class Flag { Flag(const char* name, std::function string_hook, string default_value_for_display, const string& usage_text); + bool is_default_initialized() const { return default_initialized_; } + private: friend class Flags; @@ -115,6 +122,7 @@ class Flag { string string_default_for_display_; string usage_text_; + bool default_initialized_ = true; }; class Flags { diff --git a/tensorflow/core/util/matmul_autotune.cc b/tensorflow/core/util/matmul_autotune.cc index 741a78a193f..c30a5d930e7 100644 --- a/tensorflow/core/util/matmul_autotune.cc +++ b/tensorflow/core/util/matmul_autotune.cc @@ -48,4 +48,22 @@ bool MatmulDoFP32ComputationFP16Input() { return value; } +int MatmulMaxAutotuneAlgorithmCount() { + int64 value; + // In CUDA 11, cublasLtMatmulAlgoGetHeuristic typically returns <= 4 + // algorithms for a given configuration, so 10 seems like a reasonable default + // here. + Status status = + ReadInt64FromEnvVar("TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS", 10, &value); + if (!status.ok()) { + LOG(ERROR) << status.error_message(); + } + static constexpr const int kMaxValue = std::numeric_limits::max(); + if (value < 1 || value > kMaxValue) { + LOG(ERROR) << "Invalid value for TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS: " + << value << " is not in range [1, " << kMaxValue << "]"; + } + return value; +} + } // namespace tensorflow diff --git a/tensorflow/core/util/matmul_autotune.h b/tensorflow/core/util/matmul_autotune.h index 5846cae2fc7..c77d274e781 100644 --- a/tensorflow/core/util/matmul_autotune.h +++ b/tensorflow/core/util/matmul_autotune.h @@ -22,6 +22,7 @@ namespace tensorflow { bool MatmulAutotuneEnable(); bool MatmulDoFP32ComputationFP16Input(); +int MatmulMaxAutotuneAlgorithmCount(); } // namespace tensorflow diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD index fa2b3d25e29..df4a05b61d4 100644 --- a/tensorflow/core/util/tensor_bundle/BUILD +++ b/tensorflow/core/util/tensor_bundle/BUILD @@ -89,8 +89,8 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/framework:tensor_testutil", ], ) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index f60f4daf071..32df8eccd08 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -14404,6 +14404,61 @@ func ExperimentalRebatchDataset(scope *Scope, input_dataset tf.Output, num_repli return op.Output(0) } +// TextLineReaderV2Attr is an optional argument to TextLineReaderV2. +type TextLineReaderV2Attr func(optionalAttr) + +// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value. +// +// value: Number of lines to skip from the beginning of every file. +// If not specified, defaults to 0 +func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["skip_header_lines"] = value + } +} + +// TextLineReaderV2Container sets the optional container attribute to value. +// +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func TextLineReaderV2Container(value string) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// TextLineReaderV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A Reader that outputs the lines of a file delimited by '\n'. +// +// Returns The handle to reference the Reader. +func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TextLineReaderV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Saves the input tensors to disk. // // The size of `tensor_names` must match the number of tensors in `data`. `data[i]` @@ -16012,6 +16067,12 @@ func DecodeImageExpandAnimations(value bool) DecodeImageAttr { // False, in which case the op will return 3-dimensional tensors and will truncate // animated GIF files to the first frame. // +// *NOTE*: If the first frame of an animated GIF does not occupy the entire +// canvas (maximum frame width x maximum frame height), then it fills the +// unoccupied areas (in the first frame) with zeros (black). For frames after the +// first frame that does not occupy the entire canvas, it uses the previous +// frame to fill the unoccupied areas. +// // Arguments: // contents: 0-D. The encoded image bytes. // @@ -30616,61 +30677,6 @@ func GRUBlockCellGrad(scope *Scope, x tf.Output, h_prev tf.Output, w_ru tf.Outpu return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// TextLineReaderV2Attr is an optional argument to TextLineReaderV2. -type TextLineReaderV2Attr func(optionalAttr) - -// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value. -// -// value: Number of lines to skip from the beginning of every file. -// If not specified, defaults to 0 -func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["skip_header_lines"] = value - } -} - -// TextLineReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func TextLineReaderV2Container(value string) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// TextLineReaderV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A Reader that outputs the lines of a file delimited by '\n'. -// -// Returns The handle to reference the Reader. -func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TextLineReaderV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Encode audio data using the WAV file format. // // This operation will generate a string suitable to be saved out to create a .wav @@ -30875,6 +30881,37 @@ func StatefulStandardNormalV2(scope *Scope, resource tf.Output, algorithm tf.Out return op.Output(0) } +// Helper used to compute the gradient for `RaggedTensorToVariant`. +// +// Computes the gradient for the dense_values input to the RaggedTensorToVariant +// op, given the variant-encoded ragged gradients of the outputs, along with +// the outer row-splits and the shape of the dense-values that were provided as +// inputs to the RaggedTensorToVariant op. +// +// Arguments: +// encoded_ragged_grad: A `variant` Tensor containing encoded `RaggedTensor` gradients. +// row_splits: Outermost row-splits that were used as input to the RaggedTensorToVariant op. +// dense_values_shape: Shape of the dense_values that was used as an input to the +// RaggedTensorToVariant op. +// +// +// Returns Gradient for the dense_values of the RaggedTensorToVariant op. +func RaggedTensorToVariantGradient(scope *Scope, encoded_ragged_grad tf.Output, row_splits tf.Output, dense_values_shape tf.Output, Tvalues tf.DataType) (dense_values_grad tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Tvalues": Tvalues} + opspec := tf.OpSpec{ + Type: "RaggedTensorToVariantGradient", + Input: []tf.Input{ + encoded_ragged_grad, row_splits, dense_values_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. type ResourceSparseApplyFtrlAttr func(optionalAttr) diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 7d22a78bcb4..fabfa19f638 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -15,6 +15,20 @@ limitations under the License. package org.tensorflow.processor; +import com.google.common.base.CaseFormat; +import com.google.common.base.Strings; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.ParameterizedTypeName; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeVariableName; +import com.squareup.javapoet.WildcardTypeName; import java.io.IOException; import java.util.Collection; import java.util.Collections; @@ -23,7 +37,6 @@ import java.util.Map; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; - import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.Filer; import javax.annotation.processing.Messager; @@ -44,21 +57,6 @@ import javax.lang.model.util.ElementFilter; import javax.lang.model.util.Elements; import javax.tools.Diagnostic.Kind; -import com.google.common.base.CaseFormat; -import com.google.common.base.Strings; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.Multimap; -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.FieldSpec; -import com.squareup.javapoet.JavaFile; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.ParameterizedTypeName; -import com.squareup.javapoet.TypeName; -import com.squareup.javapoet.TypeSpec; -import com.squareup.javapoet.TypeVariableName; -import com.squareup.javapoet.WildcardTypeName; - /** * A compile-time Processor that aggregates classes annotated with {@link * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please @@ -109,7 +107,7 @@ public final class OperatorProcessor extends AbstractProcessor { // If there are no annotated elements, claim the annotation but do nothing. if (annotated.size() == 0) { - return true; + return false; } // This processor has to aggregate all op classes in one round, as it generates a single Ops @@ -124,25 +122,25 @@ public final class OperatorProcessor extends AbstractProcessor { + "One reason this can happen is if other annotation processors generate\n" + "new @Operator source files."); } - return true; + return false; } // Collect all classes tagged with our annotation. Multimap groupedMethods = HashMultimap.create(); if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) { - return true; + return false; } // Nothing to do when there are no tagged classes. if (groupedMethods.isEmpty()) { - return true; + return false; } // Validate operator classes and generate Op API. writeApi(groupedMethods); hasRun = true; - return true; + return false; } @Override @@ -410,7 +408,8 @@ public final class OperatorProcessor extends AbstractProcessor { .returns(T_OPS) .addStatement("return new Ops(scope.withControlDependencies(controls))") .addJavadoc( - "Returns an API that adds operations to the graph with the provided control dependencies.\n\n" + "Returns an API that adds operations to the graph with the provided control" + + " dependencies.\n\n" + "@see {@link $T#withControlDependencies(Iterable>)}\n", T_SCOPE) .build()); @@ -457,8 +456,10 @@ public final class OperatorProcessor extends AbstractProcessor { .returns(T_OPS) .addStatement("return new Ops(new $T($T.getDefault()))", T_SCOPE, T_EAGER_SESSION) .addJavadoc( - "Creates an API for building operations in the default eager execution environment\n\n" - + "

Invoking this method is equivalent to {@code Ops.create(EagerSession.getDefault())}.\n") + "Creates an API for building operations in the default eager execution" + + " environment\n\n" + + "

Invoking this method is equivalent to {@code" + + " Ops.create(EagerSession.getDefault())}.\n") .build()); return opsBuilder.build(); diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index c4a0b5d525d..597f81194cd 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -309,8 +309,7 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = tflite_copts() + TFLITE_DEFAULT_COPTS, visibility = [ - "//tensorflow/lite:__subpackages__", - "//tensorflow_lite_support:__subpackages__", + "//visibility:public", ], deps = [ "//tensorflow/lite:stderr_reporter", @@ -325,8 +324,7 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = tflite_copts() + TFLITE_DEFAULT_COPTS, visibility = [ - "//tensorflow/lite:__subpackages__", - "//tensorflow_lite_support:__subpackages__", + "//visibility:public", ], deps = [ ":minimal_logging", @@ -341,8 +339,7 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = tflite_copts() + TFLITE_DEFAULT_COPTS, visibility = [ - "//tensorflow/lite:__subpackages__", - "//tensorflow_lite_support:__subpackages__", + "//visibility:public", ], deps = [ "//tensorflow/lite:mutable_op_resolver", @@ -357,8 +354,7 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = tflite_copts() + TFLITE_DEFAULT_COPTS, visibility = [ - "//tensorflow/lite:__subpackages__", - "//tensorflow_lite_support:__subpackages__", + "//visibility:public", ], deps = [ ":util", @@ -628,6 +624,15 @@ cc_test( ], ) +cc_test( + name = "stderr_reporter_test", + srcs = ["stderr_reporter_test.cc"], + deps = [ + ":stderr_reporter", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "util", srcs = ["util.cc"], diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index 77a79f11559..a75728e8a9d 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -381,7 +381,7 @@ add_library(tensorflow::tensorflowlite ALIAS tensorflow-lite) # Benchmark Tool populate_source_vars("${TFLITE_SOURCE_DIR}/tools/benchmark" TFLITE_BENCHMARK_SRCS - FILTER "(_test|_plus_flex_main|_performance_options_main)\\.cc$" + FILTER "(_test|_plus_flex_main|_performance_options.*)\\.cc$" ) list(APPEND TFLITE_BENCHMARK_SRCS ${TF_SOURCE_DIR}/core/util/stats_calculator.cc @@ -402,6 +402,13 @@ list(APPEND TFLITE_BENCHMARK_LIBS ${CMAKE_DL_LIBS} ) +# TODO(b/171007016): Enable performance options on Windows. +if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Windows") + list(APPEND TFLITE_BENCHMARK_SRCS + ${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_performance_options.cc + ) +endif() + if(TFLITE_ENABLE_XNNPACK) list(APPEND TFLITE_BENCHMARK_SRCS ${TFLITE_SOURCE_DIR}/tools/delegates/xnnpack_delegate_provider.cc diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 57c04be11f0..0f731c43577 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -169,6 +169,33 @@ def tflite_cc_shared_object( def tf_to_tflite(name, src, options, out): """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. + Args: + name: Name of rule. + src: name of the input graphdef file. + options: options passed to TFLite Converter. + out: name of the output flatbuffer file. + """ + + toco_cmdline = " ".join([ + "$(location //tensorflow/lite/python:tflite_convert)", + "--experimental_new_converter", + ("--graph_def_file=$(location %s)" % src), + ("--output_file=$(location %s)" % out), + ] + options) + native.genrule( + name = name, + srcs = [src], + outs = [out], + cmd = toco_cmdline, + tools = ["//tensorflow/lite/python:tflite_convert"] + tf_binary_additional_srcs(), + ) + +def DEPRECATED_tf_to_tflite(name, src, options, out): + """DEPRECATED Convert a frozen tensorflow graphdef to TF Lite's flatbuffer, using toco. + + Please use tf_to_tflite instead. + TODO(b/138396996): Migrate away from this deprecated rule. + Args: name: Name of rule. src: name of the input graphdef file. diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index a37607f6260..35952090e6c 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -155,6 +155,7 @@ typedef enum { kTfLiteBuiltinSegmentSum = 125, kTfLiteBuiltinBatchMatmul = 126, kTfLiteBuiltinPlaceholderForGreaterOpCodes = 127, + kTfLiteBuiltinCumsum = 128, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h index e205f075b43..a511e51b5bf 100644 --- a/tensorflow/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -465,6 +465,11 @@ typedef struct { int body_subgraph_index; } TfLiteWhileParams; +typedef struct { + bool exclusive; + bool reverse; +} TfLiteCumsumParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index 8917c254825..e04e1a12cd4 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -46,8 +46,17 @@ extern "C" { typedef enum TfLiteStatus { kTfLiteOk = 0, + + // Generally referring to an error in the runtime (i.e. interpreter) kTfLiteError = 1, + + // Generally referring to an error from a TfLiteDelegate itself. kTfLiteDelegateError = 2, + + // Generally referring to an error in applying a delegate due to + // incompatibility between runtime and delegate, e.g., this error is returned + // when trying to apply a TfLite delegate onto a model graph that's already + // immutable. kTfLiteApplicationError = 3 } TfLiteStatus; diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD index 93661996e9e..38b2e295da2 100644 --- a/tensorflow/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -49,8 +49,7 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = tflite_copts() + micro_copts(), visibility = [ - "//tensorflow/lite:__subpackages__", - "//tensorflow_lite_support:__subpackages__", + "//visibility:public", ], deps = [ ":error_reporter", @@ -68,8 +67,7 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = tflite_copts() + micro_copts(), visibility = [ - "//tensorflow/lite:__subpackages__", - "//tensorflow_lite_support:__subpackages__", + "//visibility:public", ], deps = [], ) diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 77621c3f2fd..ea381801505 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -761,6 +761,16 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } + case BuiltinOperator_CUMSUM: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* cumsum_params = op->builtin_options_as_CumsumOptions()) { + params->exclusive = cumsum_params->exclusive(); + params->reverse = cumsum_params->reverse(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } // Below are the ops with no builtin_data structure. case BuiltinOperator_BATCH_TO_SPACE_ND: // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are @@ -1431,9 +1441,6 @@ TfLiteStatus ParseReshape(const Operator* op, ErrorReporter* error_reporter, if (schema_params != nullptr) { const flatbuffers::Vector* new_shape = schema_params->new_shape(); - // TODO(b/147203660): We need to figure out when dynamic reshape - // (new_shape is a tensor) happens, why the option is not a nullptr. - // But nonethless, we should only copy when new_shape is not a nullptr. if (new_shape != nullptr) { TF_LITE_ENSURE_STATUS( FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 2b9246a1100..cece4ecba87 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1526,7 +1526,7 @@ TfLiteStatus Subgraph::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { ReportError( "Attempting to use a delegate that only supports static-sized " "tensors with a graph that has dynamic-sized tensors."); - return kTfLiteError; + return kTfLiteApplicationError; } } diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index b94d1a0b2bc..ed3b55ce630 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -567,7 +567,8 @@ class Subgraph { // delegate*. The Subgraph has been restored to its pre-delegation state. // NOTE: This reverts all delegates previously applied to the Subgraph. // 3. kTfLiteApplicationError : Delegation failed to be applied due to the - // state that the TfLite runtime is in. However, the Subgraph is still in a + // incompatibility with the TfLite runtime, e.g., the model graph is already + // immutable when applying the delegate. However, the Subgraph is still in a // invokable state. // 4. kTfLiteError: Unexpected/runtime failure. TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index c616b081829..63171348b74 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -55,6 +55,7 @@ cc_library( ":cl_device", ":gpu_object", ":opencl_wrapper", + ":serialization_cc_fbs", ":tensor_type", ":util", "//tensorflow/lite/delegates/gpu/common:access_type", @@ -358,6 +359,7 @@ cc_library( deps = [ ":cl_context", ":opencl_wrapper", + ":serialization_cc_fbs", "//tensorflow/lite/delegates/gpu/common:access_type", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:status", @@ -366,19 +368,30 @@ cc_library( cc_library( name = "inference_context", - srcs = ["inference_context.cc"], - hdrs = ["inference_context.h"], + srcs = [ + "inference_context.cc", + "serialization.cc", + ], + hdrs = [ + "inference_context.h", + "serialization.h", + ], deps = [ + ":arguments", ":buffer", ":cl_command_queue", + ":cl_context", ":cl_device", ":environment", ":gpu_object", + ":linear_storage", ":model_hints", ":opencl_wrapper", ":precision", + ":serialization_cc_fbs", ":storage_type_util", ":tensor_type", + ":texture2d", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector", "//tensorflow/lite/delegates/gpu/cl/selectors:special_selector", @@ -396,6 +409,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/types:span", ], ) @@ -467,6 +481,14 @@ cc_library( ], ) +flatbuffer_cc_library( + name = "serialization_cc_fbs", + srcs = ["serialization.fbs"], + flatc_args = [ + "--scoped-enums", + ], +) + cc_library( name = "storage_type_util", srcs = ["storage_type_util.cc"], diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc index dfd8a8680fc..e2135d05b53 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.cc +++ b/tensorflow/lite/delegates/gpu/cl/api.cc @@ -570,6 +570,56 @@ TensorObjectDef TensorToDef(const Tensor& tensor) { return def; } +CalculationsPrecision GetPrecision(const Environment& env, + const InferenceOptions& options) { + CalculationsPrecision precision; + switch (GetPosition(options, InferencePriority::MAX_PRECISION)) { + case 1: + precision = CalculationsPrecision::F32; + break; + case 2: + precision = CalculationsPrecision::F32_F16; + break; + case 3: + precision = CalculationsPrecision::F16; + break; + default: + precision = CalculationsPrecision::F16; + break; + } + // Increase precision if lower precision is not supported. + if (!env.IsSupported(precision)) { + precision = CalculationsPrecision::F32_F16; + if (!env.IsSupported(precision)) { + precision = CalculationsPrecision::F32; + } + } + return precision; +} + +TensorStorageType GetStorageTypeFromOptions(const Environment& env, + const InferenceOptions& options) { + // Fallback to BUFFER that should be supported by default. + std::vector preferred_storage_types; + if (GetRelativeImportance(options, InferencePriority::MIN_LATENCY, + InferencePriority::MIN_MEMORY_USAGE) == + PriorityImportance::HIGHER) { + preferred_storage_types = {GetFastestStorageType(env.device().GetInfo()), + TensorStorageType::BUFFER}; + } else { + preferred_storage_types = { + GetStorageTypeWithMinimalMemoryConsumption(env.device().GetInfo()), + TensorStorageType::BUFFER}; + } + + for (TensorStorageType storage_type : preferred_storage_types) { + if (env.IsSupported(storage_type)) { + return storage_type; + } + } + return TensorStorageType::UNKNOWN; +} + class InferenceBuilderImpl : public InferenceBuilder { public: explicit InferenceBuilderImpl(Environment* environment) @@ -580,8 +630,9 @@ class InferenceBuilderImpl : public InferenceBuilder { const GraphFloat32& graph) { context_ = absl::make_unique(); InferenceContext::CreateInferenceInfo create_info; - create_info.precision = GetPrecision(options); - create_info.storage_type = GetStorageType(options); + create_info.precision = GetPrecision(*environment_, options); + create_info.storage_type = + GetStorageTypeFromOptions(*environment_, options); if (options.usage == InferenceUsage::FAST_SINGLE_ANSWER) { create_info.hints.Add(ModelHints::kReduceKernelsCount); create_info.hints.Add(ModelHints::kFastTuning); @@ -590,6 +641,30 @@ class InferenceBuilderImpl : public InferenceBuilder { } RETURN_IF_ERROR(context_->InitFromGraph(create_info, graph, environment_)); +#ifdef CL_DELEGATE_ALLOW_GL + if (env_options.IsGlAware() && + IsGlSharingSupported(environment_->device())) { + gl_interop_fabric_ = absl::make_unique( + env_options.egl_display, environment_); + } + tie_factory_ = absl::make_unique( + environment_, context_.get(), gl_interop_fabric_.get()); +#else + tie_factory_ = + absl::make_unique(environment_, context_.get()); +#endif + + inputs_ = LinkTensors(context_->GetInputIds(), AccessType::READ); + outputs_ = LinkTensors(context_->GetOutputIds(), AccessType::WRITE); + return absl::OkStatus(); + } + + absl::Status Initialize(const InferenceEnvironmentOptions& env_options, + const std::vector& serialized_model) { + context_ = absl::make_unique(); + RETURN_IF_ERROR( + context_->RestoreDeserialized(serialized_model, environment_)); + #ifdef CL_DELEGATE_ALLOW_GL if (env_options.IsGlAware() && IsGlSharingSupported(environment_->device())) { @@ -671,55 +746,6 @@ class InferenceBuilderImpl : public InferenceBuilder { } private: - TensorStorageType GetStorageType(const InferenceOptions& options) const { - // Fallback to BUFFER that should be supported by default. - std::vector preferred_storage_types; - if (GetRelativeImportance(options, InferencePriority::MIN_LATENCY, - InferencePriority::MIN_MEMORY_USAGE) == - PriorityImportance::HIGHER) { - preferred_storage_types = { - GetFastestStorageType(environment_->device().GetInfo()), - TensorStorageType::BUFFER}; - } else { - preferred_storage_types = {GetStorageTypeWithMinimalMemoryConsumption( - environment_->device().GetInfo()), - TensorStorageType::BUFFER}; - } - - for (TensorStorageType storage_type : preferred_storage_types) { - if (environment_->IsSupported(storage_type)) { - return storage_type; - } - } - return TensorStorageType::UNKNOWN; - } - - CalculationsPrecision GetPrecision(const InferenceOptions& options) const { - CalculationsPrecision precision; - switch (GetPosition(options, InferencePriority::MAX_PRECISION)) { - case 1: - precision = CalculationsPrecision::F32; - break; - case 2: - precision = CalculationsPrecision::F32_F16; - break; - case 3: - precision = CalculationsPrecision::F16; - break; - default: - precision = CalculationsPrecision::F16; - break; - } - // Increase precision if lower precision is not supported. - if (!environment_->IsSupported(precision)) { - precision = CalculationsPrecision::F32_F16; - if (!environment_->IsSupported(precision)) { - precision = CalculationsPrecision::F32; - } - } - return precision; - } - // Links internal tensors with external user-facing objects. std::vector LinkTensors(const std::vector& ids, AccessType access) { @@ -840,6 +866,39 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { return environment_.Init(); } + absl::Status BuildSerializedModel( + const InferenceOptions& options, GraphFloat32 model, + std::vector* serialized_model) final { + if (!IsValid(options)) { + return absl::InvalidArgumentError("InferenceOptions are invalid."); + } + InferenceOptions resolved_options = options; + ResolveAutoPriority(&resolved_options); + if (environment_.program_cache() && + !options_.serialized_binary_cache.empty()) { + // Ignore returned error. Cache is discarded. + environment_.program_cache() + ->AddSerializedCache(environment_.context(), environment_.device(), + options_.serialized_binary_cache) + .IgnoreError(); + } + + RETURN_IF_ERROR(RunGraphTransforms(&model)); + InferenceContext context; + InferenceContext::CreateInferenceInfo create_info; + create_info.precision = GetPrecision(environment_, options); + create_info.storage_type = GetStorageTypeFromOptions(environment_, options); + if (options.usage == InferenceUsage::FAST_SINGLE_ANSWER) { + create_info.hints.Add(ModelHints::kReduceKernelsCount); + create_info.hints.Add(ModelHints::kFastTuning); + } else if (options.usage == InferenceUsage::SUSTAINED_SPEED) { + create_info.hints.Add(ModelHints::kAllowSpecialKernels); + } + RETURN_IF_ERROR(context.InitFromGraph(create_info, model, &environment_, + serialized_model)); + return absl::OkStatus(); + } + absl::Status NewInferenceBuilder( const InferenceOptions& options, GraphFloat32 model, std::unique_ptr* builder) final { @@ -865,6 +924,24 @@ class InferenceEnvironmentImpl : public InferenceEnvironment { return absl::OkStatus(); } + absl::Status NewInferenceBuilder( + const std::vector& serialized_model, + std::unique_ptr* builder) final { + if (environment_.program_cache() && + !options_.serialized_binary_cache.empty()) { + // Ignore returned error. Cache is discarded. + environment_.program_cache() + ->AddSerializedCache(environment_.context(), environment_.device(), + options_.serialized_binary_cache) + .IgnoreError(); + } + + auto builder_impl = absl::make_unique(&environment_); + RETURN_IF_ERROR(builder_impl->Initialize(options_, serialized_model)); + *builder = std::move(builder_impl); + return absl::OkStatus(); + } + std::vector GetSerializedBinaryCache() const final { std::vector data; // Is there was a problem, data would be empty. diff --git a/tensorflow/lite/delegates/gpu/cl/api.h b/tensorflow/lite/delegates/gpu/cl/api.h index 826d4f2bc78..65671117522 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.h +++ b/tensorflow/lite/delegates/gpu/cl/api.h @@ -75,6 +75,20 @@ class InferenceEnvironment { public: virtual ~InferenceEnvironment() {} + // Converts GraphFloat32 into intermediate, device-specific representation. + // This serialized_model specific for device and InferenceOptions. + // serialized_model cannot be used with another device or InferenceOptions. + // Loading serialized_model is much faster than loading GraphFloat32. + // serialized_model must be used with appropriate NewInferenceBuilder + // method (see below). + virtual absl::Status BuildSerializedModel( + const InferenceOptions& options, GraphFloat32 model, + std::vector* serialized_model) = 0; + + virtual absl::Status NewInferenceBuilder( + const std::vector& serialized_model, + std::unique_ptr* builder) = 0; + virtual absl::Status NewInferenceBuilder( const InferenceOptions& options, GraphFloat32 model, std::unique_ptr* builder) = 0; diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.cc b/tensorflow/lite/delegates/gpu/cl/arguments.cc index 526a09f18f9..7c5e635816e 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.cc +++ b/tensorflow/lite/delegates/gpu/cl/arguments.cc @@ -659,56 +659,64 @@ absl::Status Arguments::Bind(cl_kernel kernel, int offset) { std::string Arguments::AddActiveArgument(const std::string& arg_name, bool use_f32_for_halfs) { - if (auto it = int_values_.find(arg_name); it != int_values_.end()) { - int int_index; - if (it->second.active) { - int_index = it->second.offset; - } else { - it->second.active = true; - it->second.offset = shared_int4s_data_.size(); - int_index = it->second.offset; - shared_int4s_data_.push_back(it->second.value); - } - std::string index = std::to_string(int_index / 4); - std::string postfixes[4] = {"x", "y", "z", "w"}; - return "shared_int4_" + index + "." + postfixes[int_index % 4]; - } - if (auto it = float_values_.find(arg_name); it != float_values_.end()) { - int float_index; - if (it->second.active) { - float_index = it->second.offset; - } else { - it->second.active = true; - it->second.offset = shared_float4s_data_.size(); - float_index = it->second.offset; - shared_float4s_data_.push_back(it->second.value); - } - std::string index = std::to_string(float_index / 4); - std::string postfixes[4] = {"x", "y", "z", "w"}; - return "shared_float4_" + index + "." + postfixes[float_index % 4]; - } - if (auto it = half_values_.find(arg_name); it != half_values_.end()) { - int half_index; - if (it->second.active) { - half_index = it->second.offset; - } else { - it->second.active = true; - if (use_f32_for_halfs) { - it->second.store_as_f32 = true; - it->second.offset = shared_float4s_data_.size(); - shared_float4s_data_.push_back(it->second.value); + { + auto it = int_values_.find(arg_name); + if (it != int_values_.end()) { + int int_index; + if (it->second.active) { + int_index = it->second.offset; } else { - it->second.offset = shared_half4s_data_.size(); - shared_half4s_data_.push_back(it->second.value); + it->second.active = true; + it->second.offset = shared_int4s_data_.size(); + int_index = it->second.offset; + shared_int4s_data_.push_back(it->second.value); } - half_index = it->second.offset; + std::string index = std::to_string(int_index / 4); + std::string postfixes[4] = {"x", "y", "z", "w"}; + return "shared_int4_" + index + "." + postfixes[int_index % 4]; } - std::string index = std::to_string(half_index / 4); - std::string postfixes[4] = {"x", "y", "z", "w"}; - if (it->second.store_as_f32) { - return "(half)(shared_float4_" + index + "." + postfixes[half_index % 4] + - ")"; - } else { + } + { + auto it = float_values_.find(arg_name); + if (it != float_values_.end()) { + int float_index; + if (it->second.active) { + float_index = it->second.offset; + } else { + it->second.active = true; + it->second.offset = shared_float4s_data_.size(); + float_index = it->second.offset; + shared_float4s_data_.push_back(it->second.value); + } + std::string index = std::to_string(float_index / 4); + std::string postfixes[4] = {"x", "y", "z", "w"}; + return "shared_float4_" + index + "." + postfixes[float_index % 4]; + } + } + { + auto it = half_values_.find(arg_name); + if (it != half_values_.end()) { + int half_index; + if (it->second.active) { + half_index = it->second.offset; + } else { + it->second.active = true; + if (use_f32_for_halfs) { + it->second.store_as_f32 = true; + it->second.offset = shared_float4s_data_.size(); + shared_float4s_data_.push_back(it->second.value); + } else { + it->second.offset = shared_half4s_data_.size(); + shared_half4s_data_.push_back(it->second.value); + } + half_index = it->second.offset; + } + std::string index = std::to_string(half_index / 4); + std::string postfixes[4] = {"x", "y", "z", "w"}; + if (it->second.store_as_f32) { + return "(half)(shared_float4_" + index + "." + + postfixes[half_index % 4] + ")"; + } return "shared_half4_" + index + "." + postfixes[half_index % 4]; } } @@ -748,24 +756,38 @@ void Arguments::ResolveObjectNames(const std::string& object_name, } } +GPUObjectDescriptor* Arguments::GetObjectDescriptor( + const std::string& object_name) const { + { + auto it = object_refs_.find(object_name); + if (it != object_refs_.end()) { + return it->second.descriptor.get(); + } + } + { + auto it = objects_.find(object_name); + if (it != objects_.end()) { + return it->second.descriptor.get(); + } + } + return nullptr; +} + absl::Status Arguments::ResolveSelector( const std::map& linkables, const std::string& object_name, const std::string& selector, const std::vector& args, const std::vector& template_args, std::string* result) { - const GPUObjectDescriptor* desc_ptr; - if (auto it = object_refs_.find(object_name); it != object_refs_.end()) { - desc_ptr = it->second.descriptor.get(); - } else if (auto it = objects_.find(object_name); it != objects_.end()) { - desc_ptr = it->second.descriptor.get(); - } else { + const GPUObjectDescriptor* desc_ptr = GetObjectDescriptor(object_name); + if (!desc_ptr) { return absl::NotFoundError( absl::StrCat("No object with name - ", object_name)); } auto names = desc_ptr->GetGPUResources().GetNames(); const auto* tensor_desc = dynamic_cast(desc_ptr); if (tensor_desc && selector == "Write") { - if (auto it = linkables.find(object_name); it != linkables.end()) { + auto it = linkables.find(object_name); + if (it != linkables.end()) { if (desc_ptr->GetAccess() != AccessType::WRITE && desc_ptr->GetAccess() != AccessType::READ_WRITE) { return absl::FailedPreconditionError(absl::StrCat( diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.h b/tensorflow/lite/delegates/gpu/cl/arguments.h index 2f87f8ae6a0..a5435c4fc2f 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.h +++ b/tensorflow/lite/delegates/gpu/cl/arguments.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -77,6 +78,11 @@ class Arguments : public ArgumentsBinder { ~Arguments() override = default; private: + friend flatbuffers::Offset Encode( + const Arguments& args, flatbuffers::FlatBufferBuilder* builder); + friend absl::Status Decode(CLContext* context, const data::Arguments* fb_args, + Arguments* args); + void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc); void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc); void AddImage2DArray(const std::string& name, @@ -119,6 +125,9 @@ class Arguments : public ArgumentsBinder { const std::vector& member_names, std::string* code); + GPUObjectDescriptor* GetObjectDescriptor( + const std::string& object_name) const; + static constexpr char kArgsPrefix[] = "args."; struct IntValue { diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_object.h b/tensorflow/lite/delegates/gpu/cl/gpu_object.h index 297a5f70858..abd77a4489b 100644 --- a/tensorflow/lite/delegates/gpu/cl/gpu_object.h +++ b/tensorflow/lite/delegates/gpu/cl/gpu_object.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -164,6 +165,10 @@ class GPUObjectDescriptor { AccessType GetAccess() const { return access_type_; } protected: + friend flatbuffers::Offset Encode( + const GPUObjectDescriptor& desc, flatbuffers::FlatBufferBuilder* builder); + friend void Decode(const data::GPUObjectDescriptor* fb_obj, + GPUObjectDescriptor* obj); mutable std::map state_vars_; AccessType access_type_; }; diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc index b834bbfffef..ca0c0319f54 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc @@ -153,7 +153,7 @@ CLNode& CLNode::operator=(CLNode&& node) { absl::Status InferenceContext::InitFromGraph( const CreateInferenceInfo& create_info, const GraphFloat32& graph, - Environment* env) { + Environment* env, std::vector* serialized_model) { CreationContext creation_context; creation_context.device = env->GetDevicePtr(); creation_context.context = &env->context(); @@ -182,10 +182,6 @@ absl::Status InferenceContext::InitFromGraph( RETURN_IF_ERROR(Compile(creation_context)); RETURN_IF_ERROR(UpdateParams()); - for (auto& node : nodes_) { - node.operation->args_.ReleaseCPURepresentation(); - } - TuningParameters tuning_parameters; tuning_parameters.queue = env->profiling_queue(); tuning_parameters.info = &env->device().info_; @@ -201,14 +197,54 @@ absl::Status InferenceContext::InitFromGraph( } } RETURN_IF_ERROR(Tune(tuning_parameters)); + + if (serialized_model) { + flatbuffers::FlatBufferBuilder builder; + auto encoded_fb = Encode(*this, &builder); + data::FinishInferenceContextBuffer(builder, encoded_fb); + serialized_model->resize(builder.GetSize()); + std::memcpy(serialized_model->data(), builder.GetBufferPointer(), + builder.GetSize()); + } + for (auto& node : nodes_) { + node.operation->args_.ReleaseCPURepresentation(); + } + return absl::OkStatus(); +} + +absl::Status InferenceContext::RestoreDeserialized( + const std::vector& serialized_model, Environment* env) { + flatbuffers::Verifier verifier(serialized_model.data(), + serialized_model.size()); + if (!data::VerifyInferenceContextBuffer(verifier)) { + return absl::DataLossError("Deserialization failed."); + } + auto decoded_fb = data::GetInferenceContext(serialized_model.data()); + RETURN_IF_ERROR(Decode(&env->context(), decoded_fb, this)); + + CreationContext creation_context; + creation_context.device = env->GetDevicePtr(); + creation_context.context = &env->context(); + creation_context.queue = env->queue(); + creation_context.cache = env->program_cache(); + + RETURN_IF_ERROR(AllocateMemory(creation_context.context)); + BindMemoryToOperations(); + for (auto& node : nodes_) { + RETURN_IF_ERROR(node.operation->CompileDeserialized(creation_context)); + } + RETURN_IF_ERROR(UpdateParams()); + for (auto& node : nodes_) { + node.operation->args_.ReleaseCPURepresentation(); + } return absl::OkStatus(); } absl::Status InferenceContext::InitFromGraphWithTransforms( const CreateInferenceInfo& create_info, GraphFloat32* graph, - Environment* env) { + Environment* env, std::vector* serialized_model) { RETURN_IF_ERROR(RunGraphTransforms(graph)); - RETURN_IF_ERROR(InitFromGraph(create_info, *graph, env)); + RETURN_IF_ERROR(InitFromGraph(create_info, *graph, env, serialized_model)); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.h b/tensorflow/lite/delegates/gpu/cl/inference_context.h index da687ffa3b5..ec8055ebcde 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.h +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/model_hints.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -65,14 +66,15 @@ class InferenceContext { }; absl::Status InitFromGraph(const CreateInferenceInfo& create_info, - const GraphFloat32& graph, Environment* env); + const GraphFloat32& graph, Environment* env, + std::vector* serialized_model = nullptr); // Applies OpenCL-specific transformations to the graph before the // initialization. These transformations are either impossible or useless in // other backends. absl::Status InitFromGraphWithTransforms( const CreateInferenceInfo& create_info, GraphFloat32* graph, - Environment* env); + Environment* env, std::vector* serialized_model = nullptr); absl::Status AddToQueue(CLCommandQueue* queue); absl::Status Profile(ProfilingCommandQueue* queue, ProfilingInfo* result); @@ -92,9 +94,19 @@ class InferenceContext { const std::vector& GetInputIds() const { return input_ids_; } const std::vector& GetOutputIds() const { return output_ids_; } + absl::Status RestoreDeserialized(const std::vector& serialized_model, + Environment* env); + private: enum TensorMemoryType { STRONG_SHAPE = 0, BUFFER = 1, VARIABLE = 2 }; + friend flatbuffers::Offset Encode( + const InferenceContext& inference, + flatbuffers::FlatBufferBuilder* builder); + friend absl::Status Decode(CLContext* context, + const data::InferenceContext* fb_inference, + InferenceContext* inference); + void CopyInAndOutIds(const GraphFloat32& graph); absl::Status ConvertOperations(const DeviceInfo& device_info, const GraphFloat32& graph, ModelHints hints); @@ -165,6 +177,32 @@ class InferenceContext { void SetNext(ValueId id) { next_ = id; } DummyTensor Get(ValueId id) { return reservations_[id]; } + std::vector> GetTensorDescs() const { + std::vector> result; + for (auto& v : reservations_) { + TensorDescriptor desc = v.second.descriptor; + desc.shape.b = v.second.shape.b; + desc.shape.h = v.second.shape.h; + desc.shape.w = v.second.shape.w; + desc.shape.d = 1; + desc.shape.c = v.second.shape.c; + result.push_back({v.first, desc}); + } + return result; + } + + void Add(const std::vector>& tensors) { + for (auto& v : tensors) { + DummyTensor dummy; + dummy.descriptor = v.second; + dummy.shape.b = v.second.shape.b; + dummy.shape.h = v.second.shape.h; + dummy.shape.w = v.second.shape.w; + dummy.shape.c = v.second.shape.c; + Add(v.first, dummy); + } + } + private: absl::flat_hash_map reservations_; ValueId next_; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index 7bce013a895..d7e7c7dd498 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -651,6 +651,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl:device_info", "//tensorflow/lite/delegates/gpu/cl:precision", "//tensorflow/lite/delegates/gpu/cl:program_cache", + "//tensorflow/lite/delegates/gpu/cl:serialization_cc_fbs", "//tensorflow/lite/delegates/gpu/cl:tensor", "//tensorflow/lite/delegates/gpu/cl:tensor_type", "//tensorflow/lite/delegates/gpu/common:access_type", diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc index 025de5a7c7f..b39f03af846 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc @@ -223,7 +223,8 @@ absl::Status GPUOperation::UpdateParams() { return absl::OkStatus(); } -absl::Status GPUOperation::Compile(const CreationContext& creation_context) { +absl::Status GPUOperation::AssembleCode(const DeviceInfo& device_info, + CLContext* context) { if (elementwise_) { auto src_desc = absl::make_unique(definition_.src_tensors[0]); @@ -241,28 +242,35 @@ absl::Status GPUOperation::Compile(const CreationContext& creation_context) { dst_tensors_names_.insert(dst_tensors_names_.begin(), "dst_tensor"); args_.AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc)); - std::string code = - GetElementWiseCode(definition_, check_src_channels_size_); elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_; - RETURN_IF_ERROR(args_.AllocateObjects(creation_context.context)); + code_ = GetElementWiseCode(definition_, check_src_channels_size_); + RETURN_IF_ERROR(args_.AllocateObjects(context)); RETURN_IF_ERROR(args_.TransformToCLCode( - creation_context.device->info_, - {{dst_tensors_names_[0], elementwise_code_}}, &code)); - RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( - code, "main_function", *creation_context.context, - *creation_context.device, &kernel_)); + device_info, {{dst_tensors_names_[0], elementwise_code_}}, &code_)); } else { - RETURN_IF_ERROR(args_.AllocateObjects(creation_context.context)); + RETURN_IF_ERROR(args_.AllocateObjects(context)); RETURN_IF_ERROR(args_.TransformToCLCode( - creation_context.device->info_, - {{dst_tensors_names_[0], elementwise_code_}}, &code_)); - RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( - code_, "main_function", compiler_options_, *creation_context.context, - *creation_context.device, &kernel_)); + device_info, {{dst_tensors_names_[0], elementwise_code_}}, &code_)); } + return absl::OkStatus(); +} + +absl::Status GPUOperation::Compile(const CreationContext& creation_context) { + RETURN_IF_ERROR( + AssembleCode(creation_context.GetDeviceInfo(), creation_context.context)); + RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( + code_, "main_function", compiler_options_, *creation_context.context, + *creation_context.device, &kernel_)); return PostCompileCheck(creation_context.device->info_, kernel_.info_); } +absl::Status GPUOperation::CompileDeserialized( + const CreationContext& creation_context) { + return creation_context.cache->GetOrCreateCLKernel( + code_, "main_function", compiler_options_, *creation_context.context, + *creation_context.device, &kernel_); +} + void GPUOperation::GetPossibleKernelWorkGroups( TuningType tuning_type, const DeviceInfo& device_info, const KernelInfo& kernel_info, std::vector* work_groups) const { @@ -329,7 +337,7 @@ int3 GPUOperation::GetGridSize() const { const int grid_z = 1; return int3(grid_x, grid_y, grid_z); } - return int3(0, 0, 0); + return grid_size_; } void GPUOperation::AddUniquePostfix(const std::string& unique_postfix) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h index fe41e78f93c..57d8690c54e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" #include "tensorflow/lite/delegates/gpu/cl/program_cache.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" @@ -129,8 +130,12 @@ class GPUOperation { absl::Status Tune(const TuningParameters& params); + absl::Status AssembleCode(const DeviceInfo& device_info, CLContext* context); + absl::Status Compile(const CreationContext& creation_context); + absl::Status CompileDeserialized(const CreationContext& creation_context); + virtual absl::Status PostCompileCheck(const DeviceInfo& device_info, const KernelInfo& kernel_info) { return absl::OkStatus(); @@ -164,6 +169,11 @@ class GPUOperation { bool check_src_channels_size_ = false; protected: + friend flatbuffers::Offset Encode( + const GPUOperation& op, flatbuffers::FlatBufferBuilder* builder); + friend absl::Status Decode(CLContext* context, + const data::GPUOperation* fb_op, GPUOperation* op); + virtual absl::Status BindArguments(ArgumentsBinder* args) { return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.cc b/tensorflow/lite/delegates/gpu/cl/serialization.cc new file mode 100644 index 00000000000..3b52fc40bdf --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/serialization.cc @@ -0,0 +1,1049 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/cl/serialization.h" + +#include + +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" +#include "tensorflow/lite/delegates/gpu/cl/buffer.h" +#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" +#include "tensorflow/lite/delegates/gpu/cl/inference_context.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" +#include "tensorflow/lite/delegates/gpu/cl/texture2d.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace { +data::AccessType ToFB(AccessType type) { + switch (type) { + case AccessType::READ: + return data::AccessType::READ; + case AccessType::WRITE: + return data::AccessType::WRITE; + case AccessType::READ_WRITE: + return data::AccessType::READ_WRITE; + default: + return data::AccessType::READ_WRITE; + } +} + +data::DataType ToFB(DataType type) { + switch (type) { + case DataType::FLOAT16: + return data::DataType::FLOAT16; + case DataType::FLOAT32: + return data::DataType::FLOAT32; + default: + return data::DataType::UNKNOWN; + } +} + +data::MemoryType ToFB(MemoryType type) { + switch (type) { + case MemoryType::CONSTANT: + return data::MemoryType::CONSTANT; + case MemoryType::GLOBAL: + return data::MemoryType::GLOBAL; + case MemoryType::LOCAL: + return data::MemoryType::LOCAL; + } +} + +data::LinearStorageType ToFB(LinearStorageType type) { + switch (type) { + case LinearStorageType::BUFFER: + return data::LinearStorageType::BUFFER; + case LinearStorageType::TEXTURE_2D: + return data::LinearStorageType::TEXTURE_2D; + } +} + +data::TensorStorageType ToFB(TensorStorageType type) { + switch (type) { + case TensorStorageType::BUFFER: + return data::TensorStorageType::BUFFER; + case TensorStorageType::IMAGE_BUFFER: + return data::TensorStorageType::IMAGE_BUFFER; + case TensorStorageType::TEXTURE_2D: + return data::TensorStorageType::TEXTURE_2D; + case TensorStorageType::TEXTURE_ARRAY: + return data::TensorStorageType::TEXTURE_ARRAY; + case TensorStorageType::TEXTURE_3D: + return data::TensorStorageType::TEXTURE_3D; + case TensorStorageType::SINGLE_TEXTURE_2D: + return data::TensorStorageType::SINGLE_TEXTURE_2D; + case TensorStorageType::UNKNOWN: + return data::TensorStorageType::UNKNOWN; + } +} + +data::Layout ToFB(Layout type) { + switch (type) { + case Layout::HWC: + return data::Layout::HWC; + case Layout::BHWC: + return data::Layout::BHWC; + case Layout::HWDC: + return data::Layout::HWDC; + case Layout::BHWDC: + return data::Layout::BHWDC; + default: + return data::Layout::UNKNOWN; + } +} + +data::CalculationsPrecision ToFB(CalculationsPrecision type) { + switch (type) { + case CalculationsPrecision::F32: + return data::CalculationsPrecision::F32; + case CalculationsPrecision::F32_F16: + return data::CalculationsPrecision::F32_F16; + case CalculationsPrecision::F16: + return data::CalculationsPrecision::F16; + } +} + +data::TensorToGrid ToFB(TensorToGrid type) { + switch (type) { + case TensorToGrid::kCustom: + return data::TensorToGrid::CUSTOM; + case TensorToGrid::kWBToX_HDToY_SToZ: + return data::TensorToGrid::WB_TO_X_HD_TO_Y_S_TO_Z; + case TensorToGrid::kWBToX_HDToY_ZIs1: + return data::TensorToGrid::WB_TO_X_HD_TO_Y_Z_IS_1; + case TensorToGrid::kWBToX_HToY_DToZ: + return data::TensorToGrid::WB_TO_X_H_TO_Y_D_TO_Z; + case TensorToGrid::kBToX_YIs1_ZIs1: + return data::TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1; + } +} + +data::CompilerOptions ToFB(CompilerOptions type) { + switch (type) { + case CompilerOptions::ADRENO_FULL_SIMD_LINE: + return data::CompilerOptions::ADRENO_FULL_SIMD_LINE; + case CompilerOptions::ADRENO_MORE_WAVES: + return data::CompilerOptions::ADRENO_MORE_WAVES; + case CompilerOptions::POWERVR_FP16: + return data::CompilerOptions::POWERVR_FP16; + case CompilerOptions::CL_OPT_DISABLE: + return data::CompilerOptions::CL_OPT_DISABLE; + case CompilerOptions::CL_2_0: + return data::CompilerOptions::CL_2_0; + case CompilerOptions::CL_3_0: + return data::CompilerOptions::CL_3_0; + } +} + +DataType ToEnum(data::DataType type) { + switch (type) { + case data::DataType::FLOAT16: + return DataType::FLOAT16; + case data::DataType::FLOAT32: + return DataType::FLOAT32; + default: + return DataType::UNKNOWN; + } +} + +AccessType ToEnum(data::AccessType type) { + switch (type) { + case data::AccessType::READ: + return AccessType::READ; + case data::AccessType::WRITE: + return AccessType::WRITE; + case data::AccessType::READ_WRITE: + return AccessType::READ_WRITE; + } +} + +MemoryType ToEnum(data::MemoryType type) { + switch (type) { + case data::MemoryType::CONSTANT: + return MemoryType::CONSTANT; + case data::MemoryType::GLOBAL: + return MemoryType::GLOBAL; + case data::MemoryType::LOCAL: + return MemoryType::LOCAL; + } +} + +LinearStorageType ToEnum(data::LinearStorageType type) { + switch (type) { + case data::LinearStorageType::BUFFER: + return LinearStorageType::BUFFER; + case data::LinearStorageType::TEXTURE_2D: + return LinearStorageType::TEXTURE_2D; + } +} + +TensorStorageType ToEnum(data::TensorStorageType type) { + switch (type) { + case data::TensorStorageType::BUFFER: + return TensorStorageType::BUFFER; + case data::TensorStorageType::IMAGE_BUFFER: + return TensorStorageType::IMAGE_BUFFER; + case data::TensorStorageType::TEXTURE_2D: + return TensorStorageType::TEXTURE_2D; + case data::TensorStorageType::TEXTURE_ARRAY: + return TensorStorageType::TEXTURE_ARRAY; + case data::TensorStorageType::TEXTURE_3D: + return TensorStorageType::TEXTURE_3D; + case data::TensorStorageType::SINGLE_TEXTURE_2D: + return TensorStorageType::SINGLE_TEXTURE_2D; + case data::TensorStorageType::UNKNOWN: + return TensorStorageType::UNKNOWN; + } +} + +Layout ToEnum(data::Layout type) { + switch (type) { + case data::Layout::HWC: + return Layout::HWC; + case data::Layout::BHWC: + return Layout::BHWC; + case data::Layout::HWDC: + return Layout::HWDC; + case data::Layout::BHWDC: + return Layout::BHWDC; + default: + return Layout::UNKNOWN; + } +} + +CalculationsPrecision ToEnum(data::CalculationsPrecision type) { + switch (type) { + case data::CalculationsPrecision::F32: + return CalculationsPrecision::F32; + case data::CalculationsPrecision::F32_F16: + return CalculationsPrecision::F32_F16; + case data::CalculationsPrecision::F16: + return CalculationsPrecision::F16; + } +} + +TensorToGrid ToEnum(data::TensorToGrid type) { + switch (type) { + case data::TensorToGrid::CUSTOM: + return TensorToGrid::kCustom; + case data::TensorToGrid::WB_TO_X_HD_TO_Y_S_TO_Z: + return TensorToGrid::kWBToX_HDToY_SToZ; + case data::TensorToGrid::WB_TO_X_HD_TO_Y_Z_IS_1: + return TensorToGrid::kWBToX_HDToY_ZIs1; + case data::TensorToGrid::WB_TO_X_H_TO_Y_D_TO_Z: + return TensorToGrid::kWBToX_HToY_DToZ; + case data::TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1: + return TensorToGrid::kBToX_YIs1_ZIs1; + } +} + +CompilerOptions ToEnum(data::CompilerOptions type) { + switch (type) { + case data::CompilerOptions::ADRENO_FULL_SIMD_LINE: + return CompilerOptions::ADRENO_FULL_SIMD_LINE; + case data::CompilerOptions::ADRENO_MORE_WAVES: + return CompilerOptions::ADRENO_MORE_WAVES; + case data::CompilerOptions::POWERVR_FP16: + return CompilerOptions::POWERVR_FP16; + case data::CompilerOptions::CL_OPT_DISABLE: + return CompilerOptions::CL_OPT_DISABLE; + case data::CompilerOptions::CL_2_0: + return CompilerOptions::CL_2_0; + case data::CompilerOptions::CL_3_0: + return CompilerOptions::CL_3_0; + } +} + +} // namespace + +flatbuffers::Offset Encode( + const int2& v, flatbuffers::FlatBufferBuilder* builder) { + data::Int2Builder int2_builder(*builder); + int2_builder.add_x(v.x); + int2_builder.add_y(v.y); + return int2_builder.Finish(); +} + +flatbuffers::Offset Encode( + const int3& v, flatbuffers::FlatBufferBuilder* builder) { + data::Int3Builder int3_builder(*builder); + int3_builder.add_x(v.x); + int3_builder.add_y(v.y); + int3_builder.add_z(v.z); + return int3_builder.Finish(); +} + +flatbuffers::Offset Encode( + const GPUObjectDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) { + std::vector> state_vars_fb; + for (auto& v0 : desc.state_vars_) { + auto key_fb = builder->CreateString(v0.first); + auto value_fb = builder->CreateString(v0.second); + data::StateVariableBuilder state_builder(*builder); + state_builder.add_key(key_fb); + state_builder.add_value(value_fb); + state_vars_fb.push_back(state_builder.Finish()); + } + auto state_vars_fb_vec = builder->CreateVector(state_vars_fb); + data::GPUObjectDescriptorBuilder obj_builder(*builder); + obj_builder.add_state_vars(state_vars_fb_vec); + obj_builder.add_access_type(ToFB(desc.access_type_)); + return obj_builder.Finish(); +} + +void Decode(const data::GPUObjectDescriptor* fb_obj, GPUObjectDescriptor* obj) { + obj->access_type_ = ToEnum(fb_obj->access_type()); + for (auto state_fb : *fb_obj->state_vars()) { + std::string key(state_fb->key()->c_str(), state_fb->key()->size()); + std::string value(state_fb->value()->c_str(), state_fb->value()->size()); + obj->state_vars_[key] = value; + } +} + +flatbuffers::Offset Encode( + const BufferDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) { + auto obj_fb = + Encode(*static_cast(&desc), builder); + + std::vector> attributes_fb; + for (auto& attr : desc.attributes) { + attributes_fb.push_back(builder->CreateString(attr)); + } + auto attributes_fb_vec = builder->CreateVector(attributes_fb); + auto data_fb = builder->CreateVector(desc.data); + data::BufferDescriptorBuilder buf_builder(*builder); + buf_builder.add_base_obj(obj_fb); + buf_builder.add_element_type(ToFB(desc.element_type)); + buf_builder.add_element_size(desc.element_size); + buf_builder.add_memory_type(ToFB(desc.memory_type)); + buf_builder.add_attributes(attributes_fb_vec); + buf_builder.add_size(desc.size); + buf_builder.add_data(data_fb); + return buf_builder.Finish(); +} + +void Decode(const data::BufferDescriptor* fb_desc, BufferDescriptor* desc) { + Decode(fb_desc->base_obj(), desc); + desc->element_type = ToEnum(fb_desc->element_type()); + desc->element_size = fb_desc->element_size(); + desc->memory_type = ToEnum(fb_desc->memory_type()); + for (auto attr_fb : *fb_desc->attributes()) { + std::string attr(attr_fb->c_str(), attr_fb->size()); + desc->attributes.push_back(attr); + } + desc->size = fb_desc->size(); + desc->data = + std::vector(fb_desc->data()->data(), + fb_desc->data()->data() + fb_desc->data()->size()); +} + +flatbuffers::Offset Encode( + const Texture2DDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) { + auto obj_fb = + Encode(*static_cast(&desc), builder); + + auto data_fb = builder->CreateVector(desc.data); + auto size_fb = Encode(desc.size, builder); + data::Texture2DDescriptorBuilder tex_builder(*builder); + tex_builder.add_base_obj(obj_fb); + tex_builder.add_element_type(ToFB(desc.element_type)); + tex_builder.add_normalized(desc.normalized); + tex_builder.add_normalized_type(ToFB(desc.normalized_type)); + tex_builder.add_size(size_fb); + tex_builder.add_data(data_fb); + return tex_builder.Finish(); +} + +void Decode(const data::Texture2DDescriptor* fb_desc, + Texture2DDescriptor* desc) { + Decode(fb_desc->base_obj(), desc); + desc->element_type = ToEnum(fb_desc->element_type()); + desc->normalized = fb_desc->normalized(); + desc->normalized_type = ToEnum(fb_desc->normalized_type()); + desc->size.x = fb_desc->size()->x(); + desc->size.y = fb_desc->size()->y(); + desc->data = + std::vector(fb_desc->data()->data(), + fb_desc->data()->data() + fb_desc->data()->size()); +} + +flatbuffers::Offset Encode( + const TensorLinearDescriptor& desc, + flatbuffers::FlatBufferBuilder* builder) { + auto obj_fb = + Encode(*static_cast(&desc), builder); + + auto data_fb = builder->CreateVector(desc.data); + data::TensorLinearDescriptorBuilder tensor_builder(*builder); + tensor_builder.add_base_obj(obj_fb); + tensor_builder.add_element_type(ToFB(desc.element_type)); + tensor_builder.add_storage_type(ToFB(desc.storage_type)); + tensor_builder.add_memory_type(ToFB(desc.memory_type)); + tensor_builder.add_size(desc.size); + tensor_builder.add_data(data_fb); + return tensor_builder.Finish(); +} + +void Decode(const data::TensorLinearDescriptor* fb_desc, + TensorLinearDescriptor* desc) { + Decode(fb_desc->base_obj(), desc); + desc->element_type = ToEnum(fb_desc->element_type()); + desc->storage_type = ToEnum(fb_desc->storage_type()); + desc->memory_type = ToEnum(fb_desc->memory_type()); + desc->size = fb_desc->size(); + desc->data = + std::vector(fb_desc->data()->data(), + fb_desc->data()->data() + fb_desc->data()->size()); +} + +flatbuffers::Offset Encode( + const TensorDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) { + auto obj_fb = + Encode(*static_cast(&desc), builder); + + data::BHWDCBuilder shape_builder(*builder); + shape_builder.add_b(desc.shape.b); + shape_builder.add_h(desc.shape.h); + shape_builder.add_w(desc.shape.w); + shape_builder.add_d(desc.shape.d); + shape_builder.add_c(desc.shape.c); + auto shape_fb = shape_builder.Finish(); + + auto data_fb = builder->CreateVector(desc.data); + data::TensorDescriptorBuilder tensor_builder(*builder); + tensor_builder.add_base_obj(obj_fb); + tensor_builder.add_data_type(ToFB(desc.data_type)); + tensor_builder.add_storage_type(ToFB(desc.storage_type)); + tensor_builder.add_layout(ToFB(desc.layout)); + tensor_builder.add_shape(shape_fb); + tensor_builder.add_data(data_fb); + return tensor_builder.Finish(); +} + +void Decode(const data::TensorDescriptor* fb_desc, TensorDescriptor* desc) { + Decode(fb_desc->base_obj(), desc); + desc->data_type = ToEnum(fb_desc->data_type()); + desc->storage_type = ToEnum(fb_desc->storage_type()); + desc->layout = ToEnum(fb_desc->layout()); + desc->shape.b = fb_desc->shape()->b(); + desc->shape.h = fb_desc->shape()->h(); + desc->shape.w = fb_desc->shape()->w(); + desc->shape.d = fb_desc->shape()->d(); + desc->shape.c = fb_desc->shape()->c(); + desc->data = + std::vector(fb_desc->data()->data(), + fb_desc->data()->data() + fb_desc->data()->size()); +} + +flatbuffers::Offset Encode( + const OperationDef& def, flatbuffers::FlatBufferBuilder* builder) { + std::vector> src_tensors_fb; + for (auto& desc : def.src_tensors) { + auto desc_fb = Encode(desc, builder); + src_tensors_fb.push_back(desc_fb); + } + + std::vector> dst_tensors_fb; + for (auto& desc : def.dst_tensors) { + auto desc_fb = Encode(desc, builder); + dst_tensors_fb.push_back(desc_fb); + } + + auto src_tensors_fb_vec = builder->CreateVector(src_tensors_fb); + auto dst_tensors_fb_vec = builder->CreateVector(dst_tensors_fb); + + data::OperationDefBuilder def_builder(*builder); + def_builder.add_precision(ToFB(def.precision)); + def_builder.add_src_tensors(src_tensors_fb_vec); + def_builder.add_dst_tensors(dst_tensors_fb_vec); + return def_builder.Finish(); +} + +void Decode(const data::OperationDef* fb_def, OperationDef* def) { + for (auto src_fb : *fb_def->src_tensors()) { + TensorDescriptor desc; + Decode(src_fb, &desc); + def->src_tensors.push_back(std::move(desc)); + } + for (auto dst_fb : *fb_def->dst_tensors()) { + TensorDescriptor desc; + Decode(dst_fb, &desc); + def->dst_tensors.push_back(std::move(desc)); + } + def->precision = ToEnum(fb_def->precision()); +} + +flatbuffers::Offset Encode( + const TensorDescriptor& desc, const ValueId& id, + flatbuffers::FlatBufferBuilder* builder) { + auto desc_fb = Encode(desc, builder); + data::TensorDescWithIdBuilder desc_builder(*builder); + desc_builder.add_desc(desc_fb); + desc_builder.add_id(id); + return desc_builder.Finish(); +} + +void Decode(const data::TensorDescWithId* fb_desc, TensorDescriptor* desc, + ValueId* id) { + Decode(fb_desc->desc(), desc); + *id = fb_desc->id(); +} + +absl::Status Decode(CLContext* context, const data::Arguments* fb_args, + Arguments* args) { + args->shared_int4s_data_ = std::vector( + fb_args->shared_int4s()->data(), + fb_args->shared_int4s()->data() + fb_args->shared_int4s()->size()); + + args->shared_float4s_data_ = std::vector( + fb_args->shared_float4s()->data(), + fb_args->shared_float4s()->data() + fb_args->shared_float4s()->size()); + + std::vector tmp = std::vector( + fb_args->shared_half4s()->data(), + fb_args->shared_half4s()->data() + fb_args->shared_half4s()->size()); + + args->shared_half4s_data_.resize(tmp.size()); + for (int i = 0; i < tmp.size(); ++i) { + args->shared_half4s_data_[i] = tmp[i]; + } + + args->int_values_.clear(); + for (auto int_values_fb : *fb_args->int_values()) { + Arguments::IntValue value; + value.value = int_values_fb->value(); + value.offset = int_values_fb->offset(); + value.active = int_values_fb->active(); + std::string name(int_values_fb->name()->c_str(), + int_values_fb->name()->size()); + args->int_values_[name] = value; + } + + args->float_values_.clear(); + for (auto float_values_fb : *fb_args->float_values()) { + Arguments::FloatValue value; + value.value = float_values_fb->value(); + value.offset = float_values_fb->offset(); + value.active = float_values_fb->active(); + std::string name(float_values_fb->name()->c_str(), + float_values_fb->name()->size()); + args->float_values_[name] = value; + } + + args->half_values_.clear(); + for (auto half_values_fb : *fb_args->half_values()) { + Arguments::HalfValue value; + value.value = half_values_fb->value(); + value.offset = half_values_fb->offset(); + value.active = half_values_fb->active(); + value.store_as_f32 = half_values_fb->store_as_f32(); + std::string name(half_values_fb->name()->c_str(), + half_values_fb->name()->size()); + args->half_values_[name] = value; + } + + for (auto buffer_pair_fb : *fb_args->buffer_objects()) { + std::string key(buffer_pair_fb->key()->c_str(), + buffer_pair_fb->key()->size()); + BufferDescriptor desc; + Decode(buffer_pair_fb->value(), &desc); + args->AddObject(key, absl::make_unique(std::move(desc))); + } + + for (auto texture_pair_fb : *fb_args->texture2d_objects()) { + std::string key(texture_pair_fb->key()->c_str(), + texture_pair_fb->key()->size()); + Texture2DDescriptor desc; + Decode(texture_pair_fb->value(), &desc); + args->AddObject(key, + absl::make_unique(std::move(desc))); + } + + for (auto tensor_pair_fb : *fb_args->tensor_linear_objects()) { + std::string key(tensor_pair_fb->key()->c_str(), + tensor_pair_fb->key()->size()); + TensorLinearDescriptor desc; + Decode(tensor_pair_fb->value(), &desc); + args->AddObject(key, + absl::make_unique(std::move(desc))); + } + + for (auto tensor_pair_fb : *fb_args->tensor_objects()) { + std::string key(tensor_pair_fb->key()->c_str(), + tensor_pair_fb->key()->size()); + TensorDescriptor desc; + Decode(tensor_pair_fb->value(), &desc); + args->AddObject(key, absl::make_unique(std::move(desc))); + } + + for (auto buffer_pair_fb : *fb_args->buffer_refs()) { + std::string key(buffer_pair_fb->key()->c_str(), + buffer_pair_fb->key()->size()); + BufferDescriptor desc; + Decode(buffer_pair_fb->value(), &desc); + auto access_type = desc.GetAccess(); + args->AddObjectRef(key, access_type, + absl::make_unique(std::move(desc))); + } + + for (auto texture_pair_fb : *fb_args->texture2d_refs()) { + std::string key(texture_pair_fb->key()->c_str(), + texture_pair_fb->key()->size()); + Texture2DDescriptor desc; + Decode(texture_pair_fb->value(), &desc); + auto access_type = desc.GetAccess(); + args->AddObjectRef(key, access_type, + absl::make_unique(std::move(desc))); + } + + for (auto tensor_pair_fb : *fb_args->tensor_linear_refs()) { + std::string key(tensor_pair_fb->key()->c_str(), + tensor_pair_fb->key()->size()); + TensorLinearDescriptor desc; + Decode(tensor_pair_fb->value(), &desc); + auto access_type = desc.GetAccess(); + args->AddObjectRef( + key, access_type, + absl::make_unique(std::move(desc))); + } + + for (auto tensor_pair_fb : *fb_args->tensor_refs()) { + std::string key(tensor_pair_fb->key()->c_str(), + tensor_pair_fb->key()->size()); + TensorDescriptor desc; + Decode(tensor_pair_fb->value(), &desc); + auto access_type = desc.GetAccess(); + args->AddObjectRef(key, access_type, + absl::make_unique(std::move(desc))); + } + + RETURN_IF_ERROR(args->AllocateObjects(context)); + RETURN_IF_ERROR(args->AddObjectArgs()); + return absl::OkStatus(); +} + +flatbuffers::Offset Encode( + const Arguments& args, flatbuffers::FlatBufferBuilder* builder) { + std::vector> int_values_fb; + for (auto& value : args.int_values_) { + auto name_fb = builder->CreateString(value.first); + data::IntValueBuilder value_builder(*builder); + value_builder.add_name(name_fb); + value_builder.add_value(value.second.value); + value_builder.add_offset(value.second.offset); + value_builder.add_active(value.second.active); + int_values_fb.push_back(value_builder.Finish()); + } + + std::vector> float_values_fb; + for (auto& value : args.float_values_) { + auto name_fb = builder->CreateString(value.first); + data::FloatValueBuilder value_builder(*builder); + value_builder.add_name(name_fb); + value_builder.add_value(value.second.value); + value_builder.add_offset(value.second.offset); + value_builder.add_active(value.second.active); + float_values_fb.push_back(value_builder.Finish()); + } + + std::vector> half_values_fb; + for (auto& value : args.half_values_) { + auto name_fb = builder->CreateString(value.first); + data::HalfValueBuilder value_builder(*builder); + value_builder.add_name(name_fb); + value_builder.add_value(value.second.value); + value_builder.add_offset(value.second.offset); + value_builder.add_active(value.second.active); + value_builder.add_store_as_f32(value.second.store_as_f32); + half_values_fb.push_back(value_builder.Finish()); + } + + std::vector> + buffer_objs_fb; + for (auto& value : args.objects_) { + const auto* buffer_desc = + dynamic_cast(value.second.descriptor.get()); + if (!buffer_desc) continue; + auto desc_fb = Encode(*buffer_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::BufferDescriptorMapValueBuilder buf_map_builder(*builder); + buf_map_builder.add_key(key_fb); + buf_map_builder.add_value(desc_fb); + buffer_objs_fb.push_back(buf_map_builder.Finish()); + } + std::vector> + texture2d_objs_fb; + for (auto& value : args.objects_) { + const auto* texture_desc = + dynamic_cast(value.second.descriptor.get()); + if (!texture_desc) continue; + auto desc_fb = Encode(*texture_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::Texture2DDescriptorMapValueBuilder tex_map_builder(*builder); + tex_map_builder.add_key(key_fb); + tex_map_builder.add_value(desc_fb); + texture2d_objs_fb.push_back(tex_map_builder.Finish()); + } + std::vector> + tensor_linear_objs_fb; + for (auto& value : args.objects_) { + const auto* tensor_desc = dynamic_cast( + value.second.descriptor.get()); + if (!tensor_desc) continue; + auto desc_fb = Encode(*tensor_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::TensorLinearDescriptorMapValueBuilder ten_map_builder(*builder); + ten_map_builder.add_key(key_fb); + ten_map_builder.add_value(desc_fb); + tensor_linear_objs_fb.push_back(ten_map_builder.Finish()); + } + std::vector> + tensor_objs_fb; + for (auto& value : args.objects_) { + const auto* tensor_desc = + dynamic_cast(value.second.descriptor.get()); + if (!tensor_desc) continue; + auto desc_fb = Encode(*tensor_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::TensorDescriptorMapValueBuilder ten_map_builder(*builder); + ten_map_builder.add_key(key_fb); + ten_map_builder.add_value(desc_fb); + tensor_objs_fb.push_back(ten_map_builder.Finish()); + } + + std::vector> + buffer_refs_fb; + for (auto& value : args.object_refs_) { + const auto* buffer_desc = + dynamic_cast(value.second.descriptor.get()); + if (!buffer_desc) continue; + auto desc_fb = Encode(*buffer_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::BufferDescriptorMapValueBuilder buf_map_builder(*builder); + buf_map_builder.add_key(key_fb); + buf_map_builder.add_value(desc_fb); + buffer_refs_fb.push_back(buf_map_builder.Finish()); + } + std::vector> + texture2d_refs_fb; + for (auto& value : args.object_refs_) { + const auto* texture_desc = + dynamic_cast(value.second.descriptor.get()); + if (!texture_desc) continue; + auto desc_fb = Encode(*texture_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::Texture2DDescriptorMapValueBuilder tex_map_builder(*builder); + tex_map_builder.add_key(key_fb); + tex_map_builder.add_value(desc_fb); + texture2d_refs_fb.push_back(tex_map_builder.Finish()); + } + std::vector> + tensor_linear_refs_fb; + for (auto& value : args.object_refs_) { + const auto* tensor_desc = dynamic_cast( + value.second.descriptor.get()); + if (!tensor_desc) continue; + auto desc_fb = Encode(*tensor_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::TensorLinearDescriptorMapValueBuilder ten_map_builder(*builder); + ten_map_builder.add_key(key_fb); + ten_map_builder.add_value(desc_fb); + tensor_linear_refs_fb.push_back(ten_map_builder.Finish()); + } + std::vector> + tensor_refs_fb; + for (auto& value : args.object_refs_) { + const auto* tensor_desc = + dynamic_cast(value.second.descriptor.get()); + if (!tensor_desc) continue; + auto desc_fb = Encode(*tensor_desc, builder); + auto key_fb = builder->CreateString(value.first); + data::TensorDescriptorMapValueBuilder ten_map_builder(*builder); + ten_map_builder.add_key(key_fb); + ten_map_builder.add_value(desc_fb); + tensor_refs_fb.push_back(ten_map_builder.Finish()); + } + + auto shared_int4s_data_fb = builder->CreateVector(args.shared_int4s_data_); + auto shared_float4s_data_fb = + builder->CreateVector(args.shared_float4s_data_); + std::vector tmp(args.shared_half4s_data_.size()); + for (int i = 0; i < tmp.size(); ++i) { + tmp[i] = args.shared_half4s_data_[i]; + } + auto shared_half4s_data_fb = builder->CreateVector(tmp); + auto int_values_fb_vec = builder->CreateVector(int_values_fb); + auto float_values_fb_vec = builder->CreateVector(float_values_fb); + auto half_values_fb_vec = builder->CreateVector(half_values_fb); + auto buffer_objs_fb_vec = builder->CreateVector(buffer_objs_fb); + auto texture2d_objs_fb_vec = builder->CreateVector(texture2d_objs_fb); + auto tensor_linear_objs_fb_vec = builder->CreateVector(tensor_linear_objs_fb); + auto tensor_objs_fb_vec = builder->CreateVector(tensor_objs_fb); + auto buffer_refs_fb_vec = builder->CreateVector(buffer_refs_fb); + auto texture2d_refs_fb_vec = builder->CreateVector(texture2d_refs_fb); + auto tensor_linear_refs_fb_vec = builder->CreateVector(tensor_linear_refs_fb); + auto tensor_refs_fb_vec = builder->CreateVector(tensor_refs_fb); + data::ArgumentsBuilder arguments_builder(*builder); + arguments_builder.add_shared_int4s(shared_int4s_data_fb); + arguments_builder.add_shared_float4s(shared_float4s_data_fb); + arguments_builder.add_shared_half4s(shared_half4s_data_fb); + arguments_builder.add_int_values(int_values_fb_vec); + arguments_builder.add_float_values(float_values_fb_vec); + arguments_builder.add_half_values(half_values_fb_vec); + arguments_builder.add_buffer_objects(buffer_objs_fb_vec); + arguments_builder.add_texture2d_objects(texture2d_objs_fb_vec); + arguments_builder.add_tensor_linear_objects(tensor_linear_objs_fb_vec); + arguments_builder.add_tensor_objects(tensor_objs_fb_vec); + arguments_builder.add_buffer_refs(buffer_refs_fb_vec); + arguments_builder.add_texture2d_refs(texture2d_refs_fb_vec); + arguments_builder.add_tensor_linear_refs(tensor_linear_refs_fb_vec); + arguments_builder.add_tensor_refs(tensor_refs_fb_vec); + return arguments_builder.Finish(); +} + +absl::Status Decode(CLContext* context, const data::GPUOperation* fb_op, + GPUOperation* op) { + RETURN_IF_ERROR(Decode(context, fb_op->arguments(), &op->args_)); + op->code_ = std::string(fb_op->code()->c_str(), fb_op->code()->size()); + op->work_group_size_.x = fb_op->work_group_size()->x(); + op->work_group_size_.y = fb_op->work_group_size()->y(); + op->work_group_size_.z = fb_op->work_group_size()->z(); + for (auto option_fb : *fb_op->compiler_options()) { + op->compiler_options_.push_back(ToEnum(option_fb->option())); + } + op->tensor_to_grid_ = ToEnum(fb_op->tensor_to_grid()); + op->elementwise_ = fb_op->elementwise(); + op->linkable_ = fb_op->linkable(); + op->check_src_channels_size_ = fb_op->check_src_channels_size(); + Decode(fb_op->definition(), &op->definition_); + op->grid_dimension_ = fb_op->grid_dimension(); + op->work_group_launch_order_.x = fb_op->work_group_launch_order()->x(); + op->work_group_launch_order_.y = fb_op->work_group_launch_order()->y(); + op->work_group_launch_order_.z = fb_op->work_group_launch_order()->z(); + op->grid_size_.x = fb_op->grid_size()->x(); + op->grid_size_.y = fb_op->grid_size()->y(); + op->grid_size_.z = fb_op->grid_size()->z(); + for (auto name_fb : *fb_op->src_tensors_names()) { + std::string name(name_fb->c_str(), name_fb->size()); + op->src_tensors_names_.push_back(std::move(name)); + } + for (auto name_fb : *fb_op->dst_tensors_names()) { + std::string name(name_fb->c_str(), name_fb->size()); + op->dst_tensors_names_.push_back(std::move(name)); + } + op->work_groups_count_.x = fb_op->work_groups_count()->x(); + op->work_groups_count_.y = fb_op->work_groups_count()->y(); + op->work_groups_count_.z = fb_op->work_groups_count()->z(); + op->linkable_count_ = fb_op->linkable_count(); + op->elementwise_code_ = std::string(fb_op->elementwise_code()->c_str(), + fb_op->elementwise_code()->size()); + return absl::OkStatus(); +} + +flatbuffers::Offset Encode( + const GPUOperation& op, flatbuffers::FlatBufferBuilder* builder) { + auto args_fb = Encode(op.args_, builder); + auto code_fb = builder->CreateString(op.code_); + auto work_group_size_fb = Encode(op.work_group_size_, builder); + std::vector> compiler_options_fb; + for (int i = 0; i < op.compiler_options_.size(); ++i) { + data::CompilerOptionBuilder option_builder(*builder); + option_builder.add_option(ToFB(op.compiler_options_[i])); + compiler_options_fb.push_back(option_builder.Finish()); + } + auto compiler_options_fb_vec = builder->CreateVector(compiler_options_fb); + + auto def_fb = Encode(op.definition_, builder); + auto work_group_launch_order_fb = + Encode(op.work_group_launch_order_, builder); + auto grid_size_fb = Encode(op.grid_size_, builder); + auto work_groups_count_fb = Encode(op.work_groups_count_, builder); + + std::vector> src_names_fb; + for (auto& name : op.src_tensors_names_) { + src_names_fb.push_back(builder->CreateString(name)); + } + auto src_names_fb_vec = builder->CreateVector(src_names_fb); + + std::vector> dst_names_fb; + for (auto& name : op.dst_tensors_names_) { + dst_names_fb.push_back(builder->CreateString(name)); + } + auto dst_names_fb_vec = builder->CreateVector(dst_names_fb); + + auto elementwise_code_fb = builder->CreateString(op.elementwise_code_); + + data::GPUOperationBuilder op_builder(*builder); + op_builder.add_arguments(args_fb); + op_builder.add_code(code_fb); + op_builder.add_work_group_size(work_group_size_fb); + op_builder.add_compiler_options(compiler_options_fb_vec); + op_builder.add_tensor_to_grid(ToFB(op.tensor_to_grid_)); + op_builder.add_elementwise(op.elementwise_); + op_builder.add_linkable(op.linkable_); + op_builder.add_check_src_channels_size(op.check_src_channels_size_); + op_builder.add_definition(def_fb); + op_builder.add_grid_dimension(op.grid_dimension_); + op_builder.add_work_group_launch_order(work_group_launch_order_fb); + op_builder.add_grid_size(grid_size_fb); + op_builder.add_src_tensors_names(src_names_fb_vec); + op_builder.add_dst_tensors_names(dst_names_fb_vec); + op_builder.add_work_groups_count(work_groups_count_fb); + op_builder.add_linkable_count(op.linkable_count_); + op_builder.add_elementwise_code(elementwise_code_fb); + return op_builder.Finish(); +} + +flatbuffers::Offset Encode( + const CLNode& node, flatbuffers::FlatBufferBuilder* builder) { + auto op_fb = Encode(*node.operation, builder); + std::vector in_ids(node.inputs.size()); + for (int i = 0; i < in_ids.size(); ++i) { + in_ids[i] = node.inputs[i]; + } + std::vector out_ids(node.outputs.size()); + for (int i = 0; i < out_ids.size(); ++i) { + out_ids[i] = node.outputs[i]; + } + auto in_ids_fb = builder->CreateVector(in_ids); + auto out_ids_fb = builder->CreateVector(out_ids); + auto name_fb = builder->CreateString(node.name); + data::CLNodeBuilder node_builder(*builder); + node_builder.add_gpu_op(op_fb); + node_builder.add_input_ids(in_ids_fb); + node_builder.add_output_ids(out_ids_fb); + node_builder.add_name(name_fb); + return node_builder.Finish(); +} + +absl::Status Decode(CLContext* context, const data::CLNode* fb_node, + CLNode* node) { + GPUOperation op; + RETURN_IF_ERROR(Decode(context, fb_node->gpu_op(), &op)); + node->operation = absl::make_unique(std::move(op)); + for (auto in_fb : *fb_node->input_ids()) { + node->inputs.push_back(in_fb); + } + for (auto out_fb : *fb_node->output_ids()) { + node->outputs.push_back(out_fb); + } + node->name = std::string(fb_node->name()->c_str(), fb_node->name()->size()); + + return absl::OkStatus(); +} + +flatbuffers::Offset Encode( + const InferenceContext& inference, + flatbuffers::FlatBufferBuilder* builder) { + std::vector in_ids(inference.input_ids_.size()); + for (int i = 0; i < in_ids.size(); ++i) { + in_ids[i] = inference.input_ids_[i]; + } + std::vector out_ids(inference.output_ids_.size()); + for (int i = 0; i < out_ids.size(); ++i) { + out_ids[i] = inference.output_ids_[i]; + } + auto in_ids_fb = builder->CreateVector(in_ids); + auto out_ids_fb = builder->CreateVector(out_ids); + + std::vector> nodes_fb; + for (int i = 0; i < inference.nodes_.size(); ++i) { + auto node_fb = Encode(inference.nodes_[i], builder); + nodes_fb.push_back(node_fb); + } + auto nodes_fb_vec = builder->CreateVector(nodes_fb); + + std::vector> tensors_fb; + auto tensors = inference.tensor_reserver_.GetTensorDescs(); + for (auto& tensor : tensors) { + auto tensor_fb = Encode(tensor.second, tensor.first, builder); + tensors_fb.push_back(tensor_fb); + } + auto tensors_fb_vec = builder->CreateVector(tensors_fb); + + std::vector> + variable_ids_and_refs_fb; + for (auto& pair : inference.variable_ids_and_refs_) { + data::PairOfValueIdsBuilder pair_builder(*builder); + pair_builder.add_first(pair.first); + pair_builder.add_second(pair.second); + variable_ids_and_refs_fb.push_back(pair_builder.Finish()); + } + auto variable_ids_and_refs_fb_vec = + builder->CreateVector(variable_ids_and_refs_fb); + + data::InferenceContextBuilder inf_builder(*builder); + inf_builder.add_need_flush(inference.need_flush_); + inf_builder.add_flush_periodically(inference.flush_periodically_); + inf_builder.add_flush_period(inference.flush_period_); + inf_builder.add_need_manual_release(inference.need_manual_release_); + inf_builder.add_precision(ToFB(inference.precision_)); + inf_builder.add_storage_type(ToFB(inference.storage_type_)); + inf_builder.add_nodes(nodes_fb_vec); + inf_builder.add_tensors(tensors_fb_vec); + inf_builder.add_input_ids(in_ids_fb); + inf_builder.add_output_ids(out_ids_fb); + inf_builder.add_variable_ids_and_refs(variable_ids_and_refs_fb_vec); + return inf_builder.Finish(); +} + +absl::Status Decode(CLContext* context, + const data::InferenceContext* fb_inference, + InferenceContext* inference) { + inference->need_flush_ = fb_inference->need_flush(); + inference->flush_periodically_ = fb_inference->flush_periodically(); + inference->flush_period_ = fb_inference->flush_period(); + inference->need_manual_release_ = fb_inference->need_manual_release(); + inference->precision_ = ToEnum(fb_inference->precision()); + inference->storage_type_ = ToEnum(fb_inference->storage_type()); + + inference->nodes_.resize(fb_inference->nodes()->size()); + int counter = 0; + for (auto node_fb : *fb_inference->nodes()) { + RETURN_IF_ERROR(Decode(context, node_fb, &inference->nodes_[counter])); + counter++; + } + + std::vector> tensors; + for (auto tensor_fb : *fb_inference->tensors()) { + TensorDescriptor desc; + Decode(tensor_fb->desc(), &desc); + tensors.push_back({tensor_fb->id(), std::move(desc)}); + } + inference->tensor_reserver_.Add(tensors); + for (auto in_fb : *fb_inference->input_ids()) { + inference->input_ids_.push_back(in_fb); + } + for (auto out_fb : *fb_inference->output_ids()) { + inference->output_ids_.push_back(out_fb); + } + + for (auto variable_id : *fb_inference->variable_ids_and_refs()) { + inference->variable_ids_and_refs_[variable_id->first()] = + variable_id->second(); + } + return absl::OkStatus(); +} + +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.fbs b/tensorflow/lite/delegates/gpu/cl/serialization.fbs new file mode 100644 index 00000000000..0c0d2241b5a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/serialization.fbs @@ -0,0 +1,278 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace tflite.gpu.cl.data; + +table Int4 { + x:int32; + y:int32; + z:int32; + w:int32; +} + +table Int3 { + x:int32; + y:int32; + z:int32; +} + +table Int2 { + x:int32; + y:int32; +} + +table IntValue { + name:string; + value:int32; + active:bool; + offset:uint32; +} + +table FloatValue { + name:string; + value:float; + active:bool; + offset:uint32; +} + +table HalfValue { + name:string; + value:float; + active:bool; + store_as_f32:bool; + offset:uint32; +} + +enum AccessType : byte { + READ = 0, + WRITE = 1, + READ_WRITE = 2, +} + +enum DataType : byte { + UNKNOWN = 0, + FLOAT32 = 1, + FLOAT16 = 2, +} + +enum MemoryType : byte { + GLOBAL = 0, + CONSTANT = 1, + LOCAL = 2, +} + +table StateVariable { + key:string; + value:string; +} + +table GPUObjectDescriptor { + state_vars:[StateVariable]; + access_type:AccessType; +} + +table BufferDescriptor { + base_obj:GPUObjectDescriptor; + element_type:DataType; + element_size:int32; + memory_type:MemoryType; + attributes:[string]; + size:int32; + data:[uint8]; +} + +table Texture2DDescriptor { + base_obj:GPUObjectDescriptor; + element_type:DataType; + normalized:bool; + normalized_type:DataType; + size:Int2; + data:[uint8]; +} + +enum LinearStorageType : byte { + BUFFER = 0, + TEXTURE_2D = 1, +} + +table TensorLinearDescriptor { + base_obj:GPUObjectDescriptor; + storage_type:LinearStorageType; + element_type:DataType; + memory_type:MemoryType; + size:int32; + data:[uint8]; +} + +enum TensorStorageType : byte { + UNKNOWN = 0, + BUFFER = 1, + IMAGE_BUFFER = 2, + TEXTURE_2D = 3, + TEXTURE_3D = 4, + TEXTURE_ARRAY = 5, + SINGLE_TEXTURE_2D = 6, +} + +enum Layout : byte { + UNKNOWN = 0, + HWC = 1, + BHWC = 2, + HWDC = 3, + BHWDC = 4, +} + +table BHWDC { + b:int32; + h:int32; + w:int32; + d:int32; + c:int32; +} + +table TensorDescriptor { + base_obj:GPUObjectDescriptor; + data_type:DataType; + storage_type:TensorStorageType; + layout:Layout; + shape:BHWDC; + data:[uint8]; +} + +table BufferDescriptorMapValue { + key:string; + value:BufferDescriptor; +} + +table Texture2DDescriptorMapValue { + key:string; + value:Texture2DDescriptor; +} + +table TensorLinearDescriptorMapValue { + key:string; + value:TensorLinearDescriptor; +} + +table TensorDescriptorMapValue { + key:string; + value:TensorDescriptor; +} + +table Arguments { + int_values:[IntValue]; + shared_int4s:[int32]; + + float_values:[FloatValue]; + shared_float4s:[float]; + + half_values:[HalfValue]; + shared_half4s:[float]; + + buffer_refs:[BufferDescriptorMapValue]; + texture2d_refs:[Texture2DDescriptorMapValue]; + tensor_linear_refs:[TensorLinearDescriptorMapValue]; + tensor_refs:[TensorDescriptorMapValue]; + + buffer_objects:[BufferDescriptorMapValue]; + texture2d_objects:[Texture2DDescriptorMapValue]; + tensor_linear_objects:[TensorLinearDescriptorMapValue]; + tensor_objects:[TensorDescriptorMapValue]; +} + +enum CalculationsPrecision : byte { + F32 = 0, + F32_F16 = 1, + F16 = 2, +} + +enum TensorToGrid : byte { + CUSTOM = 0, + WB_TO_X_HD_TO_Y_S_TO_Z = 1, + WB_TO_X_HD_TO_Y_Z_IS_1 = 2, + WB_TO_X_H_TO_Y_D_TO_Z = 3, + B_TO_X_Y_IS_1_Z_IS_1 = 4, +} + +enum CompilerOptions : byte { + ADRENO_FULL_SIMD_LINE = 0, + ADRENO_MORE_WAVES = 1, + POWERVR_FP16 = 2, + CL_OPT_DISABLE = 3, + CL_2_0 = 4, + CL_3_0 = 5, +} + +table OperationDef { + precision:CalculationsPrecision; + src_tensors:[TensorDescriptor]; + dst_tensors:[TensorDescriptor]; +} + +table CompilerOption { + option:CompilerOptions; +} + +table GPUOperation { + arguments:Arguments; + code:string; + work_group_size:Int3; + compiler_options:[CompilerOption]; + tensor_to_grid:TensorToGrid; + elementwise:bool; + linkable:bool; + check_src_channels_size:bool; + definition:OperationDef; + grid_dimension:int32; + work_group_launch_order:Int3; + grid_size:Int3; + src_tensors_names:[string]; + dst_tensors_names:[string]; + work_groups_count:Int3; + linkable_count:int32; + elementwise_code:string; +} + +table TensorDescWithId { + desc:TensorDescriptor; + id:int32; +} + +table CLNode { + gpu_op:GPUOperation; + input_ids:[int32]; + output_ids:[int32]; + name:string; +} + +table PairOfValueIds { + first:int32; + second:int32; +} + +table InferenceContext { + need_flush:bool; + flush_periodically:bool; + flush_period:int32; + need_manual_release:bool; + precision:CalculationsPrecision; + storage_type:TensorStorageType; + nodes:[CLNode]; + tensors:[TensorDescWithId]; + input_ids:[int32]; + variable_ids_and_refs:[PairOfValueIds]; + output_ids:[int32]; +} + +root_type InferenceContext; diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.h b/tensorflow/lite/delegates/gpu/cl/serialization.h new file mode 100644 index 00000000000..1273e62a100 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/serialization.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SERIALIZATION_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SERIALIZATION_H_ + +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" +#include "tensorflow/lite/delegates/gpu/cl/inference_context.h" +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace cl { + +class InferenceContext; + +flatbuffers::Offset Encode( + const InferenceContext& inference, flatbuffers::FlatBufferBuilder* builder); + +absl::Status Decode(CLContext* context, + const data::InferenceContext* fb_inference, + InferenceContext* inference); + +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SERIALIZATION_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc index 72c53c5b1ac..c35554b875b 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc @@ -605,8 +605,11 @@ absl::Status Tensor::CreateFromDescriptor(const TensorDescriptor& desc, descriptor_.layout = desc.layout; memory_owner_ = true; CLMemory memory; - RETURN_IF_ERROR(AllocateTensorMemory(*context, shape_, descriptor_, - desc.data.data(), &memory)); + uint8_t* data_ptr = desc.data.empty() + ? nullptr + : const_cast(desc.data.data()); + RETURN_IF_ERROR( + AllocateTensorMemory(*context, shape_, descriptor_, data_ptr, &memory)); memory_ = memory.Release(); if (desc.storage_type == TensorStorageType::IMAGE_BUFFER) { RETURN_IF_ERROR(CreateImageBufferFromBuffer( diff --git a/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc b/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc index 3e9b614c8c4..be297546709 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc +++ b/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/time/time.h" +#include "absl/types/span.h" #include "tensorflow/lite/delegates/gpu/api.h" #include "tensorflow/lite/delegates/gpu/cl/api.h" #include "tensorflow/lite/delegates/gpu/cl/environment.h" @@ -85,8 +86,18 @@ void CompareCPUGPUResults(tflite::Interpreter* cpu, } } // namespace +absl::Status RunModelSampleWithInternalAPISerializedKernels( + const std::string& model_name, const std::vector& kernel_cache); + +absl::Status RunModelSampleWithInternalAPISerialized( + tflite::Interpreter* cpu, const std::vector& in_refs, + const std::vector& out_refs, + const std::vector& kernel_cache, + const std::vector& serialized_model); + // Run Jet with OpenCL internal API and compares correctness with TFLite CPU -absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) { +absl::Status RunModelSampleWithInternalAPI(const std::string& model_name, + std::vector* kernel_cache) { auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str()); ops::builtin::BuiltinOpResolver op_resolver; @@ -124,6 +135,7 @@ absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) { return absl::InternalError("Failed to Invoke CPU inference."); } + const auto start = std::chrono::high_resolution_clock::now(); GraphFloat32 graph_cl; RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl)); @@ -156,6 +168,7 @@ absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) { options.priority2 = InferencePriority::MIN_MEMORY_USAGE; options.priority3 = InferencePriority::MAX_PRECISION; options.usage = InferenceUsage::SUSTAINED_SPEED; + RETURN_IF_ERROR( inf_env->NewInferenceBuilder(options, std::move(graph_cl), &builder)); @@ -176,6 +189,15 @@ absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) { // Builds runner. RETURN_IF_ERROR(builder->Build(&runner)); + const auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Initialization total time - " << (end - start).count() * 1e-6f + << "ms" << std::endl; + + if (kernel_cache) { + *kernel_cache = inf_env->GetSerializedBinaryCache(); + std::cout << "Kernel cache size - " << kernel_cache->size() << std::endl; + } + // Sets the input/output object. for (int i = 0; i < in_refs.size(); ++i) { TfLiteTensor* tensor_ptr = cpu_inference->tensor(in_refs[i]); @@ -198,6 +220,205 @@ absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) { return absl::OkStatus(); } +absl::Status RunModelSampleWithInternalAPISerializedKernels( + const std::string& model_name, const std::vector& kernel_cache) { + auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str()); + + ops::builtin::BuiltinOpResolver op_resolver; + InterpreterBuilder tfl_builder(*flatbuffer, op_resolver); + + // CPU. + std::unique_ptr cpu_inference; + tfl_builder(&cpu_inference); + if (!cpu_inference) { + return absl::InternalError("Failed to build CPU inference."); + } + auto status = cpu_inference->AllocateTensors(); + if (status != kTfLiteOk) { + return absl::InternalError("Failed to AllocateTensors for CPU inference."); + } + for (int k = 0; k < cpu_inference->inputs().size(); ++k) { + TfLiteTensor* tensor_ptr = + cpu_inference->tensor(cpu_inference->inputs()[k]); + if (tensor_ptr->type != kTfLiteFloat32) { + return absl::InvalidArgumentError( + "Internal api supports only F32 input tensors"); + } + } + for (int k = 0; k < cpu_inference->outputs().size(); ++k) { + TfLiteTensor* tensor_ptr = + cpu_inference->tensor(cpu_inference->outputs()[k]); + if (tensor_ptr->type != kTfLiteFloat32) { + return absl::InvalidArgumentError( + "Internal api supports only F32 output tensors"); + } + } + FillInputTensors(cpu_inference.get()); + status = cpu_inference->Invoke(); + if (status != kTfLiteOk) { + return absl::InternalError("Failed to Invoke CPU inference."); + } + + const auto start = std::chrono::high_resolution_clock::now(); + GraphFloat32 graph_cl; + RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl)); + + auto inputs = graph_cl.inputs(); + auto outputs = graph_cl.outputs(); + std::vector in_refs(inputs.size()); + std::vector out_refs(outputs.size()); + for (int i = 0; i < inputs.size(); ++i) { + in_refs[i] = inputs[i]->tensor.ref; + } + for (int i = 0; i < outputs.size(); ++i) { + out_refs[i] = outputs[i]->tensor.ref; + } + + Environment env; + RETURN_IF_ERROR(CreateEnvironment(&env)); + + std::unique_ptr inf_env; + // Initializes environment. + InferenceEnvironmentOptions env_options; + env_options.device = env.device().id(); + env_options.context = env.context().context(); + env_options.command_queue = env.queue()->queue(); + env_options.serialized_binary_cache = + absl::MakeSpan(kernel_cache.data(), kernel_cache.size()); + RETURN_IF_ERROR(NewInferenceEnvironment(env_options, &inf_env, nullptr)); + + InferenceOptions options; + options.priority1 = InferencePriority::MIN_LATENCY; + options.priority2 = InferencePriority::MIN_MEMORY_USAGE; + options.priority3 = InferencePriority::MAX_PRECISION; + options.usage = InferenceUsage::SUSTAINED_SPEED; + + std::vector serialized_model; + RETURN_IF_ERROR(inf_env->BuildSerializedModel(options, std::move(graph_cl), + &serialized_model)); + std::unique_ptr builder; + RETURN_IF_ERROR(inf_env->NewInferenceBuilder(serialized_model, &builder)); + + // Sets input/output object def for builder_. + ObjectDef obj_def; + obj_def.data_type = DataType::FLOAT32; + obj_def.data_layout = DataLayout::BHWC; + obj_def.object_type = ObjectType::CPU_MEMORY; + obj_def.user_provided = true; + for (int i = 0; i < in_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetInputObjectDef(i, obj_def)); + } + for (int i = 0; i < out_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetOutputObjectDef(i, obj_def)); + } + + std::unique_ptr<::tflite::gpu::InferenceRunner> runner; + // Builds runner. + RETURN_IF_ERROR(builder->Build(&runner)); + + const auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Initialization total time(with kernel cache) - " + << (end - start).count() * 1e-6f << "ms" << std::endl; + + // Sets the input/output object. + for (int i = 0; i < in_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu_inference->tensor(in_refs[i]); + RETURN_IF_ERROR(runner->SetInputObject( + i, CpuMemory{tensor_ptr->data.data, tensor_ptr->bytes})); + } + + std::vector> output_tensors(out_refs.size()); + for (int i = 0; i < out_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu_inference->tensor(out_refs[i]); + output_tensors[i].resize(tensor_ptr->bytes / 4); + RETURN_IF_ERROR(runner->SetOutputObject( + i, CpuMemory{output_tensors[i].data(), tensor_ptr->bytes})); + } + + RETURN_IF_ERROR(runner->Run()); + + CompareCPUGPUResults(cpu_inference.get(), out_refs, output_tensors, 1e-4f); + + RETURN_IF_ERROR(RunModelSampleWithInternalAPISerialized( + cpu_inference.get(), in_refs, out_refs, kernel_cache, serialized_model)); + + return absl::OkStatus(); +} + +absl::Status RunModelSampleWithInternalAPISerialized( + tflite::Interpreter* cpu, const std::vector& in_refs, + const std::vector& out_refs, + const std::vector& kernel_cache, + const std::vector& serialized_model) { + FillInputTensors(cpu); + auto status = cpu->Invoke(); + if (status != kTfLiteOk) { + return absl::InternalError("Failed to Invoke CPU inference."); + } + + const auto start = std::chrono::high_resolution_clock::now(); + + Environment env; + RETURN_IF_ERROR(CreateEnvironment(&env)); + + std::unique_ptr inf_env; + // Initializes environment. + InferenceEnvironmentOptions env_options; + env_options.device = env.device().id(); + env_options.context = env.context().context(); + env_options.command_queue = env.queue()->queue(); + env_options.serialized_binary_cache = + absl::MakeSpan(kernel_cache.data(), kernel_cache.size()); + RETURN_IF_ERROR(NewInferenceEnvironment(env_options, &inf_env, nullptr)); + + std::unique_ptr builder; + RETURN_IF_ERROR(inf_env->NewInferenceBuilder(serialized_model, &builder)); + + // Sets input/output object def for builder_. + ObjectDef obj_def; + obj_def.data_type = DataType::FLOAT32; + obj_def.data_layout = DataLayout::BHWC; + obj_def.object_type = ObjectType::CPU_MEMORY; + obj_def.user_provided = true; + for (int i = 0; i < in_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetInputObjectDef(i, obj_def)); + } + for (int i = 0; i < out_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetOutputObjectDef(i, obj_def)); + } + + std::unique_ptr<::tflite::gpu::InferenceRunner> runner; + // Builds runner. + RETURN_IF_ERROR(builder->Build(&runner)); + + const auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Serialized initialization total time - " + << (end - start).count() * 1e-6f << "ms" << std::endl; + + // Sets the input/output object. + for (int i = 0; i < in_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu->tensor(in_refs[i]); + RETURN_IF_ERROR(runner->SetInputObject( + i, CpuMemory{tensor_ptr->data.data, tensor_ptr->bytes})); + } + + std::vector> output_tensors(out_refs.size()); + for (int i = 0; i < out_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu->tensor(out_refs[i]); + output_tensors[i].resize(tensor_ptr->bytes / 4); + RETURN_IF_ERROR(runner->SetOutputObject( + i, CpuMemory{output_tensors[i].data(), tensor_ptr->bytes})); + } + + RETURN_IF_ERROR(runner->Run()); + + std::cout << "Comparing results second time:" << std::endl; + + CompareCPUGPUResults(cpu, out_refs, output_tensors, 1e-4f); + + return absl::OkStatus(); +} + } // namespace cl } // namespace gpu } // namespace tflite @@ -214,7 +435,15 @@ int main(int argc, char** argv) { return -1; } - auto run_status = tflite::gpu::cl::RunModelSampleWithInternalAPI(argv[1]); + std::vector kernel_cache; + auto run_status = + tflite::gpu::cl::RunModelSampleWithInternalAPI(argv[1], &kernel_cache); + if (!run_status.ok()) { + std::cerr << run_status.message(); + return -1; + } + run_status = tflite::gpu::cl::RunModelSampleWithInternalAPISerializedKernels( + argv[1], kernel_cache); if (!run_status.ok()) { std::cerr << run_status.message(); return -1; diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 7125064d7a8..99d915f0ed2 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -150,6 +150,7 @@ cc_library( ":tensor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "//tensorflow/lite/delegates:utils", "//tensorflow/lite:kernel_api", diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index ebb5d628cdc..c200f0926aa 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.h b/tensorflow/lite/delegates/gpu/metal_delegate.h index ea9da126954..e1e2ed7ff58 100644 --- a/tensorflow/lite/delegates/gpu/metal_delegate.h +++ b/tensorflow/lite/delegates/gpu/metal_delegate.h @@ -47,10 +47,15 @@ typedef struct { bool allow_precision_loss; TFLGpuDelegateWaitType wait_type; // Allows execution of integer quantized models - // TODO(b/169350710): Enable by default. bool enable_quantization; } TFLGpuDelegateOptions; +// Populates TFLGpuDelegateOptions as follows: +// allow_precision_loss = false; +// wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive; +// enable_quantization = true; +TFL_CAPI_EXPORT extern TFLGpuDelegateOptions TFLGpuDelegateOptionsDefault(void); + // Creates a new delegate instance that need to be destroyed with // `TFLDeleteTfLiteGpuDelegate` when delegate is no longer used by TFLite. // When `options` is set to `nullptr`, the following default values are used: diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm index e97e89d54c0..1d66afa938f 100644 --- a/tensorflow/lite/delegates/gpu/metal_delegate.mm +++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm @@ -177,10 +177,7 @@ class Delegate { if (options) { options_ = *options; } else { - // Default options. - options_.allow_precision_loss = false; - options_.enable_quantization = false; - options_.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive; + options_ = TFLGpuDelegateOptionsDefault(); } metal_device_ = MTLCreateSystemDefaultDevice(); command_queue_ = [metal_device_ newCommandQueue]; @@ -732,3 +729,12 @@ bool TFLGpuDelegateSetCommandEncoder( metal_delegate->SetCommandEncoder(encoder, control_encoder); return true; } + +TFLGpuDelegateOptions TFLGpuDelegateOptionsDefault() { + TFLGpuDelegateOptions options = { + .allow_precision_loss = false, + .wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive, + .enable_quantization = true, + }; + return options; +} diff --git a/tensorflow/lite/delegates/hexagon/java/src/main/native/hexagon_delegate_jni.cc b/tensorflow/lite/delegates/hexagon/java/src/main/native/hexagon_delegate_jni.cc index 387ccb21ed3..9254b824dc6 100644 --- a/tensorflow/lite/delegates/hexagon/java/src/main/native/hexagon_delegate_jni.cc +++ b/tensorflow/lite/delegates/hexagon/java/src/main/native/hexagon_delegate_jni.cc @@ -45,6 +45,7 @@ Java_org_tensorflow_lite_HexagonDelegate_setAdspLibraryPath( std::stringstream path; path << lib_dir_path << ";/system/lib/rfsa/adsp;/system/vendor/lib/rfsa/adsp;/dsp"; + env->ReleaseStringUTFChars(native_lib_path, lib_dir_path); return setenv("ADSP_LIBRARY_PATH", path.str().c_str(), 1 /*override*/) == 0 ? JNI_TRUE : JNI_FALSE; diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc index d5c86acf16f..345fd6da168 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_device_selection_test.cc @@ -270,7 +270,10 @@ class ArgMaxOpModel : public SingleOpModel, public AcceleratedModel { SetBuiltinOp(BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions, CreateArgMaxOptions(builder_, output_type).Union()); - BuildInterpreter({input_shape, {1}}); + BuildInterpreter({input_shape, {1}}, /*num_threads*/ -1, + /*allow_fp32_relax_to_fp16=*/false, + /*apply_delegate=*/false); + ApplyDelegate(); } }; @@ -410,7 +413,8 @@ class AddSubOpsAcceleratedModel : public MultiOpModel, public AcceleratedModel { {add_output, input3_}, {output_}); BuildInterpreter({GetShape(input1_), GetShape(input2_), GetShape(input3_)}, /*num_threads=*/-1, allow_fp32_relax_to_fp16, - /*apply_delegate=*/true); + /*apply_delegate=*/false); + ApplyDelegate(); } }; @@ -591,7 +595,8 @@ class HardSwishAddOpsAcceleratedModel : public MultiOpModel, CreateAddOptions(builder_, activation_type).Union(), {input1_, hard_swish_output}, {output_}); BuildInterpreter({GetShape(input1_), GetShape(input2_)}, /*num_threads=*/-1, - allow_fp32_relax_to_fp16, /*apply_delegate=*/true); + allow_fp32_relax_to_fp16, /*apply_delegate=*/false); + ApplyDelegate(); } }; @@ -721,7 +726,8 @@ class QuantizedWeightsConvolutionOpModel : public SingleOpModel, BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}, num_threads, /*allow_fp32_relax_to_fp16=*/false, - /*apply_delegate=*/true); + /*apply_delegate=*/false); + ApplyDelegate(); } void SetInput(std::initializer_list data) { @@ -867,7 +873,11 @@ class LongIdentityModel : public MultiOpModel, public AcceleratedModel { {intermediate_outputs[intermediate_outputs.size() - 1], zero_input_}, {output_}); - BuildInterpreter({GetShape(input_), GetShape(zero_input_)}); + BuildInterpreter({GetShape(input_), GetShape(zero_input_)}, + /*num_threads*/ -1, + /*allow_fp32_relax_to_fp16=*/false, + /*apply_delegate=*/false); + ApplyDelegate(); std::vector zero(GetTensorSize(input_), 0.0); PopulateTensor(zero_input_, zero); diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_errno_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_errno_test.cc index f347799b4b8..976876d04e2 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_errno_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_errno_test.cc @@ -75,7 +75,11 @@ class FloatAddOpModel : public SingleOpModelWithNNAPI { SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, CreateAddOptions(builder_, activation_type).Union()); BuildInterpreter({GetShape(input1_), GetShape(input2_)}, /*num_threads=*/-1, - allow_fp32_relax_to_fp16, /*apply_delegate=*/true); + allow_fp32_relax_to_fp16, /*apply_delegate=*/false); + // We defer applying the 'stateful_delegate_' till now (i.e. via setting + // 'apply_delegate=false' above) so that default TfLite delegates won't be + // applied. + ApplyDelegate(); } }; diff --git a/tensorflow/lite/experimental/acceleration/configuration/configuration.proto b/tensorflow/lite/experimental/acceleration/configuration/configuration.proto index 497bf3cb58c..15ff046cb05 100644 --- a/tensorflow/lite/experimental/acceleration/configuration/configuration.proto +++ b/tensorflow/lite/experimental/acceleration/configuration/configuration.proto @@ -127,6 +127,11 @@ message NNAPISettings { // dynamic dimensions of the model. // By default this is set to false. optional bool allow_dynamic_dimensions = 9; + + // Whether to allow the NNAPI accelerator to optionally use lower-precision + // float16 (16-bit floating point) arithmetic when doing calculations on + // float32 (32-bit floating point). + optional bool allow_fp16_precision_for_fp32 = 10; } // Which GPU backend to select. Default behaviour on Android is to try OpenCL diff --git a/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.cc b/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.cc index cf99f530d6d..30dda0daa5b 100644 --- a/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.cc +++ b/tensorflow/lite/experimental/acceleration/configuration/nnapi_plugin.cc @@ -96,6 +96,7 @@ class NnapiPlugin : public DelegatePluginInterface { !nnapi_settings->allow_nnapi_cpu_on_android_10_plus(); options_.execution_priority = ConvertExecutionPriority(nnapi_settings->execution_priority()); + options_.allow_fp16 = nnapi_settings->allow_fp16_precision_for_fp32(); } private: diff --git a/tensorflow/lite/experimental/objc/apis/TFLMetalDelegate.h b/tensorflow/lite/experimental/objc/apis/TFLMetalDelegate.h index f73c5dedfe5..026e945d888 100644 --- a/tensorflow/lite/experimental/objc/apis/TFLMetalDelegate.h +++ b/tensorflow/lite/experimental/objc/apis/TFLMetalDelegate.h @@ -57,7 +57,7 @@ typedef NS_ENUM(NSUInteger, TFLMetalDelegateThreadWaitType) { /** * Indicates whether the GPU delegate allows execution of an 8-bit quantized model. The default is - * `false`. + * `true`. */ @property(nonatomic, getter=isQuantizationEnabled) BOOL quantizationEnabled; diff --git a/tensorflow/lite/experimental/objc/sources/TFLMetalDelegate.m b/tensorflow/lite/experimental/objc/sources/TFLMetalDelegate.m index 85894e61cd1..e5bdb18967b 100644 --- a/tensorflow/lite/experimental/objc/sources/TFLMetalDelegate.m +++ b/tensorflow/lite/experimental/objc/sources/TFLMetalDelegate.m @@ -29,6 +29,7 @@ NS_ASSUME_NONNULL_BEGIN - (instancetype)init { self = [super init]; if (self != nil) { + _quantizationEnabled = true; _waitType = TFLMetalDelegateThreadWaitTypePassive; } return self; diff --git a/tensorflow/lite/experimental/swift/Sources/MetalDelegate.swift b/tensorflow/lite/experimental/swift/Sources/MetalDelegate.swift index 4d0060231f6..0eefe353e58 100644 --- a/tensorflow/lite/experimental/swift/Sources/MetalDelegate.swift +++ b/tensorflow/lite/experimental/swift/Sources/MetalDelegate.swift @@ -62,8 +62,8 @@ extension MetalDelegate { public var waitType: ThreadWaitType = .passive /// Indicates whether the GPU delegate allows execution of an 8-bit quantized model. The default - /// is `false`. - public var isQuantizationEnabled = false + /// is `true`. + public var isQuantizationEnabled = true /// Creates a new instance with the default values. public init() {} diff --git a/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec.template b/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec.template index 1e414f1959f..f627554df78 100644 --- a/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec.template +++ b/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec.template @@ -52,8 +52,7 @@ Pod::Spec.new do |s| metal.test_spec 'Tests' do |ts| ts.source_files = swift_dir + 'Tests/{Interpreter,MetalDelegate}Tests.swift' ts.resources = [ - tfl_dir + 'testdata/add.bin', - tfl_dir + 'testdata/add_quantized.bin', + tfl_dir + 'testdata/multi_add.bin', ] end end diff --git a/tensorflow/lite/experimental/swift/Tests/MetalDelegateTests.swift b/tensorflow/lite/experimental/swift/Tests/MetalDelegateTests.swift index 8e8de7c320d..4983d7fc324 100644 --- a/tensorflow/lite/experimental/swift/Tests/MetalDelegateTests.swift +++ b/tensorflow/lite/experimental/swift/Tests/MetalDelegateTests.swift @@ -34,18 +34,26 @@ class MetalDelegateTests: XCTestCase { } func testInitInterpreterWithDelegate() throws { + // If metal device is not available, skip. + if MTLCreateSystemDefaultDevice() == nil { + return + } let metalDelegate = MetalDelegate() - let interpreter = try Interpreter(modelPath: AddQuantizedModel.path, delegates: [metalDelegate]) + let interpreter = try Interpreter(modelPath: MultiAddModel.path, delegates: [metalDelegate]) XCTAssertEqual(interpreter.delegates?.count, 1) XCTAssertNil(interpreter.options) } func testInitInterpreterWithOptionsAndDelegate() throws { + // If metal device is not available, skip. + if MTLCreateSystemDefaultDevice() == nil { + return + } var options = Interpreter.Options() options.threadCount = 1 let metalDelegate = MetalDelegate() let interpreter = try Interpreter( - modelPath: AddQuantizedModel.path, + modelPath: MultiAddModel.path, options: options, delegates: [metalDelegate] ) @@ -91,3 +99,16 @@ class MetalDelegateOptionsTests: XCTestCase { XCTAssertNotEqual(options1, options2) } } + + +/// Values for the `multi_add.bin` model. +enum MultiAddModel { + static let info = (name: "multi_add", extension: "bin") + + static var path: String = { + let bundle = Bundle(for: MetalDelegateTests.self) + guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" } + return path + }() +} + diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc index 47550be2a21..14d7219f304 100644 --- a/tensorflow/lite/experimental/writer/option_writer_generator.cc +++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc @@ -80,6 +80,7 @@ static const char* param_structs[] = {"TfLiteAddParams", "TfLiteUnpackParams", "TfLiteReverseSequenceParams", "TfLiteWhileParams", + "TfLiteCumsumParams", nullptr}; } // namespace diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml index 22fd564635c..37a3682f473 100644 --- a/tensorflow/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -139,7 +139,6 @@ upper_tabs: path: /lite/performance/measurement - title: "Delegates" path: /lite/performance/delegates - status: experimental - title: "GPU delegate" path: /lite/performance/gpu - title: "Advanced GPU" @@ -152,6 +151,9 @@ upper_tabs: - title: "Core ML delegate" path: /lite/performance/coreml_delegate status: experimental + - title: "Implementing a delegate" + path: /lite/performance/implementing_delegate + status: experimental - heading: "Optimize a model" - title: "Overview" diff --git a/tensorflow/lite/g3doc/convert/index.md b/tensorflow/lite/g3doc/convert/index.md index 36774205770..070e801030b 100644 --- a/tensorflow/lite/g3doc/convert/index.md +++ b/tensorflow/lite/g3doc/convert/index.md @@ -74,7 +74,7 @@ import tensorflow as tf # Convert the model converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory -tflite_model = converter.convert(). +tflite_model = converter.convert() # Save the model. with open('model.tflite', 'wb') as f: diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md index 09ce3a12a49..f152ce69f7f 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md @@ -42,6 +42,10 @@ API. ## Run inference in Java +See the +[Object Detection reference app](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android/) +for an example of how to use `ObjectDetector` in an Android app. + ### Step 1: Import Gradle dependency and other settings Copy the `.tflite` model file to the assets directory of the Android module diff --git a/tensorflow/lite/g3doc/models/image_classification/overview.md b/tensorflow/lite/g3doc/models/image_classification/overview.md index 8f401e916b5..a9b1d49e0dc 100644 --- a/tensorflow/lite/g3doc/models/image_classification/overview.md +++ b/tensorflow/lite/g3doc/models/image_classification/overview.md @@ -2,42 +2,29 @@ -Use a pre-trained and optimized model to identify hundreds of classes of -objects, including people, activities, animals, plants, and places. +The task of identifying what an image represents is called _image +classification_. An image classification model is trained to recognize various +classes of images. For example, you may train a model to recognize photos +representing three different types of animals: rabbits, hamsters, and dogs. +TensorFlow Lite provides optimized pre-trained models that you can deploy in +your mobile applications. Learn more about image classification using TensorFlow +[here](https://www.tensorflow.org/tutorials/images/classification). + +The following image shows the output of the image classification model on +Android. + +Screenshot of Android example ## Get started -If you are unfamiliar with the concept of image classification, you should start -by reading What is image -classification? - -To learn how to use image classification in a mobile app, we recommend exploring -our Example applications and -guides. - -If you are using a platform other than Android or iOS, or you are already -familiar with the TensorFlow Lite APIs, you can download our starter image -classification model and the accompanying labels. - -Download -starter model and labels - -Once you have the starter model running on your target device, you can -experiment with different models to find the optimal balance between -performance, accuracy, and model size. For guidance, see -Choose a different model. - -### Example applications and guides - -We have example applications for image classification for both Android and iOS. -For each example, we provide a guide that explains how it works. - -#### Android +If you are new to TensorFlow Lite and are working with Android or iOS, it is +recommended you explore the following example applications that can help you get +started. You can leverage the out-of-box API from TensorFlow Lite Task Library to [integrate image classification models](../../inference_with_metadata/task_library/image_classifier) in just a few lines of code. You can also -[build your own custom inference pipleline](../../inference_with_metadata/lite_support) +[build your own custom inference pipeline](../../inference_with_metadata/lite_support) using the TensorFlow Lite Support Library. The Android example below demonstrates the implementation for both methods as @@ -49,39 +36,34 @@ respectively. View Android example -Read the -[Android example guide](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/EXPLORE_THE_CODE.md) -to learn how the app works. - -#### iOS - View iOS example -Read the -[iOS example guide](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/ios/EXPLORE_THE_CODE.md) -to learn how the app works. +If you are using a platform other than Android/iOS, or if you are already +familiar with the +[TensorFlow Lite APIs](https://www.tensorflow.org/api_docs/python/tf/lite), +download the starter model and supporting files (if applicable). -#### Screenshot +Download +starter model -The following screenshot shows the Android image classification example. +## Model description -Screenshot of Android example +### How it works -## What is image classification? +During training, an image classification model is fed images and their +associated _labels_. Each label is the name of a distinct concept, or class, +that the model will learn to recognize. -A common use of machine learning is to identify what an image represents. For -example, we might want to know what type of animal appears in the following -photograph. +Given sufficient training data (often hundreds or thousands of images per +label), an image classification model can learn to predict whether new images +belong to any of the classes it has been trained on. This process of prediction +is called _inference_. Note that you can also use +[transfer learning](https://www.tensorflow.org/tutorials/images/transfer_learning) +to identify new classes of images by using a pre-existing model. Transfer +learning does not require a very large training dataset. -dog - -The task of predicting what an image represents is called _image -classification_. An image classification model is trained to recognize various -classes of images. For example, a model might be trained to recognize photos -representing three different types of animals: rabbits, hamsters, and dogs. - -When we subsequently provide a new image as input to the model, it will output +When you subsequently provide a new image as input to the model, it will output the probabilities of the image representing each of the types of animal it was trained on. An example output might be as follows: @@ -108,63 +90,10 @@ trained on. An example output might be as follows: -Based on the output, we can see that the classification model has predicted that -the image has a high probability of representing a dog. - -Note: Image classification can only tell you the probability that an image -represents one or more of the classes that the model was trained on. It cannot -tell you the position or identity of objects within the image. If you need to -identify objects and their positions within images, you should use an -object detection model. - -### Training, labels, and inference - -During training, an image classification model is fed images and their -associated _labels_. Each label is the name of a distinct concept, or class, -that the model will learn to recognize. - -Given sufficient training data (often hundreds or thousands of images per -label), an image classification model can learn to predict whether new images -belong to any of the classes it has been trained on. This process of prediction -is called _inference_. - -To perform inference, an image is passed as input to a model. The model will -then output an array of probabilities between 0 and 1. With our example model, -this process might look like the following: - - - - - - -
dog[0.07, 0.02, 0.91]
- -Each number in the output corresponds to a label in our training data. -Associating our output with the three labels the model was trained on, we can -see the model has predicted a high probability that the image represents a dog. - - - - - - - - - - - - - - - - - - - - - - -
LabelProbability
rabbit0.07
hamster0.02
dog0.91
+Each number in the output corresponds to a label in the training data. +Associating the output with the three labels the model was trained on, you can +see that the model has predicted a high probability that the image represents a +dog. You might notice that the sum of all the probabilities (for rabbit, hamster, and dog) is equal to 1. This is a common type of output for models with multiple @@ -172,12 +101,18 @@ classes (see Softmax for more information). -### Ambiguous results +Note: Image classification can only tell you the probability that an image +represents one or more of the classes that the model was trained on. It cannot +tell you the position or identity of objects within the image. If you need to +identify objects and their positions within images, you should use an +object detection model. -Since the probabilities will always sum to 1, if the image is not confidently -recognized as belonging to any of the classes the model was trained on you may -see the probability distributed throughout the labels without any one value -being significantly larger. +

Ambiguous results

+ +Since the output probabilities will always sum to 1, if an image is not +confidently recognized as belonging to any of the classes the model was trained +on you may see the probability distributed throughout the labels without any one +value being significantly larger. For example, the following might indicate an ambiguous result: @@ -203,13 +138,29 @@ For example, the following might indicate an ambiguous result: +If your model frequently returns ambiguous results, you may need a different, +more accurate model. -### Uses and limitations +

Choosing a model architecture

-The image classification models that we provide are useful for single-label -classification, which means predicting which single label the image is most -likely to represent. They are trained to recognize 1000 classes of image. For a -full list of classes, see the labels file in the +TensorFlow Lite provides you with a variety of image classification models which +are all trained on the original dataset. Model architectures like MobileNet, +Inception, and NASNet are available on the +hosted models page. To choose the best model for +your use case, you need to consider the individual architectures as well as some +of the tradeoffs between various models. Some of these model tradeoffs are based +on metrics such as performance, accuracy, and model size. For example, you might +need a faster model for building a bar code scanner while you might prefer a +slower, more accurate model for a medical imaging app. + +Note that the image classification models provided accept varying sizes of input. For some models, this is indicated in the filename. For example, the Mobilenet_V1_1.0_224 model accepts an input of 224x224 pixels. All of the models require three color channels per pixel (red, green, and blue). Quantized models require 1 byte per channel, and float models require 4 bytes per channel. The Android and iOS code samples demonstrate how to process full-sized camera images into the required format for each model. + +

Uses and limitations

+ +The TensorFlow Lite image classification models are useful for single-label +classification; that is, predicting which single label the image is most likely to +represent. They are trained to recognize 1000 image classes. For a full list of +classes, see the labels file in the model zip. @@ -225,13 +176,43 @@ For the following use cases, you should use a different type of model: Once you have the starter model running on your target device, you can experiment with different models to find the optimal balance between -performance, accuracy, and model size. For guidance, see -Choose a different model. +performance, accuracy, and model size. -## Performance benchmarks +

Customize model

-Performance benchmark numbers are generated with the tool -[described here](https://www.tensorflow.org/lite/performance/benchmarks). +The pre-trained models provided are trained to recognize 1000 classes of images. +For a full list of classes, see the labels file in the +model +zip. + +You can also use transfer learning to re-train a model to +recognize classes not in the original set. For example, you could re-train the +model to distinguish between different species of tree, despite there being no +trees in the original training data. To do this, you will need a set of training +images for each of the new labels you wish to train. + +Learn how to perform transfer learning in the +Recognize +flowers with TensorFlow codelab, or with the +Model Maker library. + +

Performance benchmarks

+ +Model performance is measured in terms of the amount of time it takes for a +model to run inference on a given piece of hardware. The lower the time, the faster +the model. + +The performance you require depends on your application. Performance can be +important for applications like real-time video, where it may be important to +analyze each frame in the time before the next frame is drawn (e.g. inference +must be faster than 33ms to perform real-time inference on a 30fps video +stream). + +The TensorFlow Lite quantized MobileNet models' performance range from 3.7ms to +80.3 ms. + +Performance benchmark numbers are generated with the +benchmarking tool. @@ -270,75 +251,35 @@ Performance benchmark numbers are generated with the tool \*\* 2 threads used on iPhone for the best performance result. -## Choose a different model +### Model accuracy -A large number of image classification models are available on our -List of hosted models. You should aim -to choose the optimal model for your application based on performance, accuracy -and model size. There are trade-offs between each of them. - -### Performance - -We measure performance in terms of the amount of time it takes for a model to -run inference on a given piece of hardware. The less time, the faster the model. - -The performance you require depends on your application. Performance can be -important for applications like real-time video, where it may be important to -analyze each frame in the time before the next frame is drawn (e.g. inference -must be faster than 33ms to perform real-time inference on a 30fps video -stream). - -Our quantized MobileNet models’ performance ranges from 3.7ms to 80.3 ms. - -### Accuracy - -We measure accuracy in terms of how often the model correctly classifies an +Accuracy is measured in terms of how often the model correctly classifies an image. For example, a model with a stated accuracy of 60% can be expected to classify an image correctly an average of 60% of the time. -Our list of hosted models provides -Top-1 and Top-5 accuracy statistics. Top-1 refers to how often the correct label -appears as the label with the highest probability in the model’s output. Top-5 -refers to how often the correct label appears in the top 5 highest probabilities -in the model’s output. +The [list of hosted models](../../guide/hosted_models.md) provides Top-1 and +Top-5 accuracy statistics. Top-1 refers to how often the correct label appears +as the label with the highest probability in the model’s output. Top-5 refers to +how often the correct label appears in the 5 highest probabilities in the +model’s output. -Our quantized MobileNet models’ Top-5 accuracy ranges from 64.4 to 89.9%. +The TensorFlow Lite quantized MobileNet models’ Top-5 accuracy range from 64.4 +to 89.9%. -### Size +### Model size The size of a model on-disk varies with its performance and accuracy. Size may be important for mobile development (where it might impact app download sizes) or when working with hardware (where available storage might be limited). -Our quantized MobileNet models’ size ranges from 0.5 to 3.4 Mb. +The TensorFlow Lite quantized MobileNet models' sizes range from 0.5 to 3.4 MB. -### Architecture +## Further reading and resources -Several different model architectures are available on -List of hosted models, indicated by -the model’s name. For example, you can choose between MobileNet, Inception, and -others. +Use the following resources to learn more about concepts related to image +classification: -The architecture of a model impacts its performance, accuracy, and size. All of -our hosted models are trained on the same data, meaning you can use the provided -statistics to compare them and choose which is optimal for your application. - -Note: The image classification models we provide accept varying sizes of input. For some models, this is indicated in the filename. For example, the Mobilenet_V1_1.0_224 model accepts an input of 224x224 pixels.

All of the models require three color channels per pixel (red, green, and blue). Quantized models require 1 byte per channel, and float models require 4 bytes per channel.

Our Android and iOS code samples demonstrate how to process full-sized camera images into the required format for each model. - -## Customize model - -The pre-trained models we provide are trained to recognize 1000 classes of -image. For a full list of classes, see the labels file in the -model -zip. - -You can use a technique known as _transfer learning_ to re-train a model to -recognize classes not in the original set. For example, you could re-train the -model to distinguish between different species of tree, despite there being no -trees in the original training data. To do this, you will need a set of training -images for each of the new labels you wish to train. - -Learn how to perform transfer learning in the -Recognize -flowers with TensorFlow codelab, or with the -[model maker toolkit](/lite/tutorials/model_maker_image_classification). +* [Image classification using TensorFlow](https://www.tensorflow.org/tutorials/images/classification) +* [Image classification with CNNs](https://www.tensorflow.org/tutorials/images/cnn) +* [Transfer learning](https://www.tensorflow.org/tutorials/images/transfer_learning) +* [Data augmentation](https://www.tensorflow.org/tutorials/images/data_augmentation) diff --git a/tensorflow/lite/g3doc/models/object_detection/overview.md b/tensorflow/lite/g3doc/models/object_detection/overview.md index cbe2fc05922..c4dfd4fd429 100644 --- a/tensorflow/lite/g3doc/models/object_detection/overview.md +++ b/tensorflow/lite/g3doc/models/object_detection/overview.md @@ -14,18 +14,13 @@ annotated: ## Get started -If you are new to TensorFlow Lite and are working with Android or iOS, download -the following example applications to get started. - -Android -example -iOS -example +To learn how to use object detection in a mobile app, explore the +Example applications and guides. If you are using a platform other than Android or iOS, or if you are already familiar with the TensorFlow Lite -APIs, you can download the starter object detection model and the +APIs, you can download our starter object detection model and the accompanying labels. Download @@ -45,6 +40,38 @@ For the following use cases, you should use a different type of model:
  • Predicting the composition of an image, for example subject versus background (see segmentation)
  • +### Example applications and guides + +If you are new to TensorFlow Lite and are working with Android or iOS, we +recommend exploring the following example applications that can help you get +started. + +#### Android + +You can leverage the out-of-box API from TensorFlow Lite Task Library to +[integrate object detection models](../../inference_with_metadata/task_library/object_detector) +in just a few lines of code. You can also +[build your own custom inference pipeline](../../guide/inference#load_and_run_a_model_in_java) +using the TensorFlow Lite Interpreter Java API. + +The Android example below demonstrates the implementation for both methods as +[lib_task_api](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android/lib_task_api) +and +[lib_interpreter](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android/lib_interpreter), +respectively. + +View +Android example + +#### iOS + +You can integrate the model using the +[TensorFlow Lite Interpreter Swift API](../../guide/inference#load_and_run_a_model_in_swift). +See the iOS example below. + +View +iOS example + ## Model description This section describes the signature for diff --git a/tensorflow/lite/g3doc/performance/delegates.md b/tensorflow/lite/g3doc/performance/delegates.md index 6b233075398..b17c9c35fec 100644 --- a/tensorflow/lite/g3doc/performance/delegates.md +++ b/tensorflow/lite/g3doc/performance/delegates.md @@ -1,31 +1,61 @@ -# TensorFlow Lite delegates +# TensorFlow Lite Delegates -Note: Delegate API is still experimental and is subject to change. +## Introduction -## What is a TensorFlow Lite delegate? +**Delegates** enable hardware acceleration of TensorFlow Lite models by +leveraging on-device accelerators such as the GPU and +[Digital Signal Processor (DSP)](https://en.wikipedia.org/wiki/Digital_signal_processor). -A TensorFlow Lite delegate is a way to delegate part or all of graph execution -to another executor. +By default, TensorFlow Lite utilizes CPU kernels that are optimized for the +[ARM Neon](https://developer.arm.com/documentation/dht0002/a/Introducing-NEON/NEON-architecture-overview/NEON-instructions) +instruction set. However, the CPU is a multi-purpose processor that isn't +necessarily optimized for the heavy arithmetic typically found in Machine +Learning models (for example, the matrix math involved in convolution and dense +layers). -## Why should I use delegates? +On the other hand, most modern mobile phones contain chips that are better at +handling these heavy operations. Utilizing them for neural network operations +provides huge benefits in terms of latency and power efficiency. For example, +GPUs can provide upto a +[5x speedup](https://blog.tensorflow.org/2020/08/faster-mobile-gpu-inference-with-opencl.html) +in latency, while the +[Qualcomm® Hexagon DSP](https://developer.qualcomm.com/software/hexagon-dsp-sdk/dsp-processor) +has shown to reduce power consumption upto 75% in our experiments. -Running inference on compute-heavy machine learning models on mobile devices is -resource demanding due to the devices' limited processing and power. +Each of these accelerators have associated APIs that enable custom computations, +such as [OpenCL](https://www.khronos.org/opencl/) or +[OpenGL ES](https://www.khronos.org/opengles/) for mobile GPU and the +[Qualcomm® Hexagon SDK](https://developer.qualcomm.com/software/hexagon-dsp-sdk) +for DSP. Typically, you would have to write a lot of custom code to run a neural +network though these interfaces. Things get even complicated when you consider +that each accelerator has its pros & cons and cannot execute every operation in +a neural network. TensorFlow Lite's Delegate API solves this problem by acting +as a bridge between the TFLite runtime and these lower-level APIs. -Instead of relying on the CPU, some devices have hardware accelerators, such as -GPU or DSP, that allows for better performance and higher energy efficiency. +![runtime with delegates](images/delegate_runtime.png) -## Using the built-in delegates +## Choosing a Delegate -TensorFlow Lite provides the following delegates for hardware acceleration: +TensorFlow Lite supports multiple delegates, each of which is optimized for +certain platform(s) and particular types of models. Usually, there will be +multiple delegates applicable to your use-case, depending on two major criteria: +the *Platform* (Android or iOS?) you target, and the *Model-type* +(floating-point or quantized?) that you are trying to accelerate. + +### Delegates by Platform + +#### Cross-platform (Android & iOS) + +* **GPU delegate** - The GPU delegate can be used on both Android and iOS. It + is optimized to run 32-bit and 16-bit float based models where a GPU is + available. It also supports 8-bit quantized models and provides GPU + performance on par with their float versions. For details on the GPU + delegate, see [TensorFlow Lite on GPU](gpu_advanced.md). For step-by-step + tutorials on using the GPU delegate with Android and iOS, see + [TensorFlow Lite GPU Delegate Tutorial](gpu.md). + +#### Android -* **GPU delegate for cross platform acceleration** - The GPU delegate can be - used on both Android and iOS. It is optimized to run 32-bit and 16-bit float - based models where a GPU is available. It also supports 8-bit quantized - models and provides GPU performance on par with their float versions. For - details on the GPU delegate, see [TensorFlow Lite on GPU](gpu_advanced.md). - For step-by-step tutorials on using the GPU delegate with Android and iOS, - see [TensorFlow Lite GPU Delegate Tutorial](gpu.md). * **NNAPI delegate for newer Android devices** - The NNAPI delegate can be used to accelerate models on Android devices with GPU, DSP and / or NPU available. It is available in Android 8.1 (API 27+) or higher. For an @@ -33,210 +63,188 @@ TensorFlow Lite provides the following delegates for hardware acceleration: practices, see [TensorFlow Lite NNAPI delegate](nnapi.md). * **Hexagon delegate for older Android devices** - The Hexagon delegate can be used to accelerate models on Android devices with Qualcomm Hexagon DSP. It - can be used on devices older version of Android OS that does not fully - support NNAPI. See [TensorFlow Lite Hexagon delegate](hexagon_delegate.md) - for more detail. + can be used on devices running older versions of Android that do not support + NNAPI. See [TensorFlow Lite Hexagon delegate](hexagon_delegate.md) for more + detail. + +#### iOS + * **Core ML delegate for newer iPhones and iPads** - For newer iPhones and iPads where Neural Engine is available, you can use Core ML delegate to - accelerate inference for 32-bit float based models. Neural Engine is - available Apple mobile devices with A12 SoC or higher. For an overview of - the Core ML delegate and step-by-step instructions, see + accelerate inference for 32-bit or 16-bit floating-point models. Neural + Engine is available Apple mobile devices with A12 SoC or higher. For an + overview of the Core ML delegate and step-by-step instructions, see [TensorFlow Lite Core ML delegate](coreml_delegate.md). -## How do delegates work? +### Delegates by model type -Let's say we have a simple model graph such as the following: +Each accelerator is designed with a certain bit-width of data in mind. If you +provide a floating-point model to a delegate that only supports 8-bit quantized +operations (such as the [Hexagon delegate](hexagon_delegate.md)), it will reject +all its operations and the model will run entirely on the CPU. To avoid such +surprises, the table below provides an overview of delegate support based on +model type: -![Original graph](../images/performance/tflite_delegate_graph_1.png "Original Graph") +**Model Type** | **GPU** | **NNAPI** | **Hexagon** | **CoreML** +------------------------------------------------------------------------------------------------------- | ------- | --------- | ----------- | ---------- +Floating-point (32 bit) | Yes | Yes | No | Yes +[Post-training float16 quantization](post_training_float16_quant.ipynb) | Yes | No | No | Yes +[Post-training dynamic range quantization](post_training_quant.ipynb) | Yes | Yes | No | No +[Post-training integer quantization](post_training_integer_quant.ipynb) | Yes | Yes | Yes | No +[Quantization-aware training](http://www.tensorflow.org/model_optimization/guide/quantization/training) | Yes | Yes | Yes | No -If a delegate was provided for specific operations, then TensorFlow Lite will -split the graph into multiple subgraphs where each subgraph will be handled by a -delegate. +### Validating performance -Let's assume that a delegate, `MyDelegate`, has a faster implementation for -Conv2D and Mean operations. The resulting main graph will be updated to look -like below. +The information in this section acts as a rough guideline for shortlisting the +delegates that could improve your application. However, it is important to note +that each delegate has a pre-defined set of operations it supports, and may +perform differently depending on the model and device; for example, the +[NNAPI delegate](nnapi.md) may choose to use Google's Edge-TPU on a Pixel phone +while utilizing a DSP on another device. Therefore, it is usually recommended +that you perform some benchmarking to gauge how useful a delegate is for your +needs. This also helps justify the binary size increase associated with +attaching a delegate to the TensorFlow Lite runtime. -![Graph with delegate](../images/performance/tflite_delegate_graph_2.png "Graph with delegate") +TensorFlow Lite has extensive performance and accuracy-evaluation tooling that +can empower developers to be confident in using delegates in their application. +These tools are discussed in the next section. -Each subgraph that is handled by a delegate will be replaced with a node that -evaluates the subgraph on its invoked call. +## Tools for Evaluation -Depending on the model, the final graph can end up with one node, which means -that all of the graphs were delegated or multiple nodes handled the subgraphs. -In general, you don't want to have multiple subgraphs handled by the delegate, -since each time you switch from delegate to the main graph, there is an overhead -for passing the results from the subgraph to the main graph. It's not always -safe to share memory. +### Latency & memory footprint -## How to add a delegate +TensorFlow Lite’s +[benchmark tool](https://www.tensorflow.org/lite/performance/measurement) can be +used with suitable parameters to estimate model performance, including average +inference latency, initialization overhead, memory footprint, etc. This tool +supports multiple flags to figure out the best delegate configuration for your +model. For instance, `--gpu_backend=gl` can be specified with `--use_gpu` to +measure GPU execution with OpenGL. The complete list of supported delegate +parameters is defined in the +[detailed documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/delegates/README.md#tflite-delegate-registrar). -_Note that the API used below is experimental and is subject to change._ +Here’s an example run for a quantized model with GPU via `adb`: -Based on the previous section, to add a delegate, we need to do the following: - -1. Define a kernel node that is responsible for evaluating the delegate - subgraph. -1. Create an instance of - [TfLiteDelegate](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/common.h#L611), - which is responsible for registering the kernel node and claiming the nodes - that the delegate can execute. - -To see it in code, let's define a delegate and call it `MyDelegate`, which can -execute Conv2D and Mean operations faster. - -```c++ -#include "tensorflow/lite/util.h" -#include "tensorflow/lite/builtin_ops.h" -#include "tensorflow/lite/context_util.h" - -// This is where the execution of the operations or whole graph happens. -// The class below has an empty implementation just as a guideline -// on the structure. -class MyDelegate { - public: - // Returns true if my delegate can handle this type of op. - static bool SupportedOp(const TfLiteRegistration* registration) { - switch (registration->builtin_code) { - case kTfLiteBuiltinConv2d: - case kTfLiteBuiltinMean: - return true; - default: - return false; - } - } - - // Any initialization code needed - bool Init() {} - // Any preparation work needed (e.g. allocate buffers) - bool Prepare(TfLiteContext* context, TfLiteNode* node) {} - // Actual running of the delegate subgraph. - bool Invoke(TfLiteContext* context, TfLiteNode* node) {} - // ... Add any other methods needed. -}; - -// Create the TfLiteRegistration for the Kernel node which will replace -// the subgraph in the main TfLite graph. -TfLiteRegistration GetMyDelegateNodeRegistration() { - // This is the registration for the Delegate Node that gets added to - // the TFLite graph instead of the subgraph it replaces. - // It is treated as an OP node. But in our case - // Init will initialize the delegate. - // Invoke will run the delegate graph. - // Prepare for preparing the delegate. - // Free for any cleaning needed by the delegate. - TfLiteRegistration kernel_registration; - kernel_registration.builtin_code = kTfLiteBuiltinDelegate; - kernel_registration.custom_name = "MyDelegate"; - kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void { - delete reinterpret_cast(buffer); - }; - kernel_registration.init = [](TfLiteContext* context, const char* buffer, - size_t) -> void* { - // In the node init phase, initialize MyDelegate instance - const TfLiteDelegateParams* delegate_params = - reinterpret_cast(buffer); - MyDelegate* my_delegate = new MyDelegate; - if (!my_delegate->Init(context, params)) { - return nullptr; - } - return my_delegate; - }; - kernel_registration.invoke = [](TfLiteContext* context, - TfLiteNode* node) -> TfLiteStatus { - MyDelegate* kernel = reinterpret_cast(node->user_data); - return kernel->Invoke(context, node); - }; - kernel_registration.prepare = [](TfLiteContext* context, - TfLiteNode* node) -> TfLiteStatus { - MyDelegate* kernel = reinterpret_cast(node->user_data); - return kernel->Prepare(context, node); - }; - - return kernel_registration; -} - -// TfLiteDelegate methods - -TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { - // Claim all nodes that can be evaluated by the delegate and ask the - // framework to update the graph with delegate kernel instead. - std::vector supported_nodes; - TfLiteIntArray* plan; - TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); - TfLiteNode* node; - TfLiteRegistration* registration; - for (int node_index : TfLiteIntArrayView(plan)) { - TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( - context, node_index, &node, ®istration)); - if (MyDelegate::SupportedOp(registration)) { - supported_nodes.push_back(node_index); - } - } - TfLiteRegistration my_delegate_kernel_registration = - GetMyDelegateNodeRegistration(); - - // This call split the graphs into subgraphs, for subgraphs that can be - // handled by the delegate, it will replace it with a - // 'my_delegate_kernel_registration' - TfLiteIntArray* supported_nodes_int_array = - ::tflite::ConvertVectorToTfLiteIntArray(supported_nodes); - auto status = context->ReplaceNodeSubsetsWithDelegateKernels( - context, my_delegate_kernel_registration, - supported_nodes_int_array, delegate); - TfLiteIntArrayFree(supported_nodes_int_array); - return status -} - -void FreeBufferHandle(TfLiteContext* context, TfLiteDelegate* delegate, - TfLiteBufferHandle* handle) { - // Do any cleanups. -} - -TfLiteStatus CopyToBufferHandle(TfLiteContext* context, - TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - TfLiteTensor* tensor) { - // Copies data from tensor to delegate buffer if needed. - return kTfLiteOk; -} - -TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, - TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - TfLiteTensor* tensor) { - // Copies the data from delegate buffer into the tensor raw memory. - return kTfLiteOk; -} - -// Caller takes ownership of the returned pointer. -TfLiteDelegate* CreateMyDelegate() { - TfLiteDelegate* delegate = new TfLiteDelegate; - - delegate->data_ = nullptr; - delegate->flags = kTfLiteDelegateFlagsNone; - delegate->Prepare = &DelegatePrepare; - // This cannot be null. - delegate->CopyFromBufferHandle = &CopyFromBufferHandle; - // This can be null. - delegate->CopyToBufferHandle = &CopyToBufferHandle; - // This can be null. - delegate->FreeBufferHandle = &FreeBufferHandle; - - return delegate; -} - - -// To add the delegate you need to call - -auto* my_delegate = CreateMyDelegate(); -if (interpreter->ModifyGraphWithDelegate(my_delegate) != - kTfLiteOk) { - // Handle error -} else { - interpreter->Invoke(); -} -... -// Don't forget to delete your delegate -delete my_delegate; ``` +adb shell /data/local/tmp/benchmark_model \ + --graph=/data/local/tmp/mobilenet_v1_224_quant.tflite \ + --use_gpu=true +``` + +You can download pre-built version of this tool for Android, 64-bit ARM +architecture +[here](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model.apk) +([more details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/android)). + +### Accuracy & correctness + +Delegates usually perform computations at a different precision than their CPU +counterparts. As a result, there is an (usually minor) accuracy tradeoff +associated with utilizing a delegate for hardware acceleration. Note that this +isn't *always* true; for example, since the GPU uses floating-point precision to +run quantized models, there might be a slight precision improvement (for e.g., +<1% Top-5 improvement in ILSVRC image classification). + +TensorFlow Lite has two types of tooling to measure how accurately a delegate +behaves for a given model: *Task-Based* and *Task-Agnostic*. All the tools +described in this section support the +[advanced delegation parameters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/delegates/README.md#tflite-delegate-registrar) +used by the benchmarking tool from the previous section. Note that the +sub-sections below focus on *delegate evaluation* (Does the delegate perform the +same as the CPU?) rather than model evaluation (Is the model itself good for the +task?). + +#### Task-Based Evaluation + +TensorFlow Lite has tools to evaluate correctness on two image-based tasks: + +* [ILSVRC 2012](http://image-net.org/challenges/LSVRC/2012/) (Image + Classification) with + [top-K accuracy](https://en.wikipedia.org/wiki/Evaluation_measures_\(information_retrieval\)#Precision_at_K) + +* [COCO Object Detection (w/ bounding boxes)](https://cocodataset.org/#detection-2020) + with + [mean Average Precision (mAP)](https://en.wikipedia.org/wiki/Evaluation_measures_\(information_retrieval\)#Mean_average_precision) + +Prebuilt binaries of these tools (Android, 64-bit ARM architecture), along with +documentation can be found here: + +* [ImageNet Image Classification](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_eval_imagenet_image_classification) + ([More details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification)) +* [COCO Object Detection](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_eval_coco_object_detection) + ([More details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/coco_object_detection)) + +The example below demonstrates +[image classification evaluation](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification) +with NNAPI utilizing Google's Edge-TPU on a Pixel 4: + +``` +adb shell /data/local/tmp/run_eval \ + --model_file=/data/local/tmp/mobilenet_quant_v1_224.tflite \ + --ground_truth_images_path=/data/local/tmp/ilsvrc_images \ + --ground_truth_labels=/data/local/tmp/ilsvrc_validation_labels.txt \ + --model_output_labels=/data/local/tmp/model_output_labels.txt \ + --output_file_path=/data/local/tmp/accuracy_output.txt \ + --num_images=0 # Run on all images. \ + --use_nnapi=true \ + --nnapi_accelerator_name=google-edgetpu +``` + +The expected output is a list of Top-K metrics from 1 to 10: + +``` +Top-1 Accuracy: 0.733333 +Top-2 Accuracy: 0.826667 +Top-3 Accuracy: 0.856667 +Top-4 Accuracy: 0.87 +Top-5 Accuracy: 0.89 +Top-6 Accuracy: 0.903333 +Top-7 Accuracy: 0.906667 +Top-8 Accuracy: 0.913333 +Top-9 Accuracy: 0.92 +Top-10 Accuracy: 0.923333 +``` + +#### Task-Agnostic Evaluation + +For tasks where there isn't an established on-device evaluation tool, or if you +are experimenting with custom models, TensorFlow Lite has the +[Inference Diff](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/inference_diff) +tool. (Android, 64-bit ARM binary architecture binary +[here](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_eval_inference_diff)) + +Inference Diff compares TensorFlow Lite execution (in terms of latency & +output-value deviation) in two settings: + +* Single-threaded CPU Inference +* User-defined Inference - defined by + [these parameters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/delegates/README.md#tflite-delegate-registrar) + +To do so, the tool generates random Gaussian data and passes it through two +TFLite Interpreters - one running single-threaded CPU kernels, and the other +parametrized by the user's arguments. + +It measures the latency of both, as well as the absolute difference between the +output tensors from each Interpreter, on a per-element basis. + +For a model with a single output tensor, the output might look like this: + +``` +Num evaluation runs: 50 +Reference run latency: avg=84364.2(us), std_dev=12525(us) +Test run latency: avg=7281.64(us), std_dev=2089(us) +OutputDiff[0]: avg_error=1.96277e-05, std_dev=6.95767e-06 +``` + +What this means is that for the output tensor at index `0`, the elements from +the CPU output different from the delegate output by an average of `1.96e-05`. + +Note that interpreting these numbers requires deeper knowledge of the model, and +what each output tensor signifies. If its a simple regression that determines +some sort of score or embedding, the difference should be low (otherwise it's an +error with the delegate). However, outputs like the 'detection class' one from +SSD models is a little harder to interpret. For example, it might show a +difference using this tool, but that may not mean something really wrong with +the delegate: consider two (fake) classes: "TV (ID: 10)", "Monitor (ID:20)" - If +a delegate is slightly off the golden truth and shows monitor instead of TV, the +output diff for this tensor might be something as high as 20-10 = 10. diff --git a/tensorflow/lite/g3doc/performance/gpu.md b/tensorflow/lite/g3doc/performance/gpu.md index 96e8aa6f9dc..077f88e1b12 100644 --- a/tensorflow/lite/g3doc/performance/gpu.md +++ b/tensorflow/lite/g3doc/performance/gpu.md @@ -12,11 +12,9 @@ resulting in lower latency. In the best scenario, inference on the GPU may now run fast enough for previously not available real-time applications. Unlike CPUs, GPUs compute with 16-bit or 32-bit floating point numbers and do -not require quantization for optimal performance. - -**NOTE:** The delegate does accept 8-bit quantized models on Android. Support on -iOS is experimental. Refer to the [advanced documentation](gpu_advanced.md) for -details. +not require quantization for optimal performance. The delegate does accept 8-bit +quantized models, but the calculation will be performed in floating point +numbers. Refer to the [advanced documentation](gpu_advanced.md) for details. Another benefit with GPU inference is its power efficiency. GPUs carry out the computations in a very efficient and optimized manner, so that they consume less diff --git a/tensorflow/lite/g3doc/performance/gpu_advanced.md b/tensorflow/lite/g3doc/performance/gpu_advanced.md index 71415693f86..d23c87c8288 100644 --- a/tensorflow/lite/g3doc/performance/gpu_advanced.md +++ b/tensorflow/lite/g3doc/performance/gpu_advanced.md @@ -325,7 +325,6 @@ Android APIs support quantized models by default. To disable, do the following: **C++ API** ```c++ -// NEW: Prepare custom options with feature enabled. TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default(); options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE; @@ -336,7 +335,6 @@ if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false; **Java API** ```java -// NEW: Prepare GPU delegate with feature turned on. GpuDelegate delegate = new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(false)); Interpreter.Options options = (new Interpreter.Options()).addDelegate(delegate); @@ -344,27 +342,21 @@ Interpreter.Options options = (new Interpreter.Options()).addDelegate(delegate); #### iOS -Support for quantized models on iOS APIs is experimental. To enable, do the -following: +iOD APIs support quantized models by default. To disable, do the following: **Swift API** ```swift -// NEW: Prepare custom options with feature enabled. var options = MetalDelegate.Options() -options.isQuantizationEnabled = true +options.isQuantizationEnabled = false let delegate = MetalDelegate(options: options) ``` **C API (also used for Objective-C)** ```c - -// THIS: -// NEW: Prepare custom options with feature enabled. -const TFLGpuDelegateOptions options = { - .enable_quantization = true, -}; +TFLGpuDelegateOptions options = TFLGpuDelegateOptionsDefault(); +options.enable_quantization = false; auto* delegate = TFLGpuDelegateCreate(options); ``` diff --git a/tensorflow/lite/g3doc/performance/images/delegate_runtime.png b/tensorflow/lite/g3doc/performance/images/delegate_runtime.png new file mode 100644 index 00000000000..e229f0fda09 Binary files /dev/null and b/tensorflow/lite/g3doc/performance/images/delegate_runtime.png differ diff --git a/tensorflow/lite/g3doc/performance/implementing_delegate.md b/tensorflow/lite/g3doc/performance/implementing_delegate.md new file mode 100644 index 00000000000..85904cad091 --- /dev/null +++ b/tensorflow/lite/g3doc/performance/implementing_delegate.md @@ -0,0 +1,171 @@ +# Implementing a Delegate + +Note: The API used below is experimental and is subject to change. + +Follow the steps below to add a delegate: + +1. Define a kernel node that is responsible for evaluating the delegate + subgraph. +1. Create an instance of + [TfLiteDelegate](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/common.h#L611), + which is responsible for registering the kernel node and claiming the nodes + that the delegate can execute. + +To see it in code, define a delegate `MyDelegate` to execute Conv2D and Mean ops +faster. + +```c++ +#include "tensorflow/lite/util.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/context_util.h" + +// This is where the execution of the operations or whole graph happens. +// The class below has an empty implementation just as a guideline +// on the structure. +class MyDelegate { + public: + // Returns true if MyDelegate can handle this type of op. + static bool SupportedOp(const TfLiteRegistration* registration) { + switch (registration->builtin_code) { + case kTfLiteBuiltinConv2d: + case kTfLiteBuiltinMean: + return true; + default: + return false; + } + } + + // Any initialization code needed + bool Init() {} + // Any preparation work needed (e.g. allocate buffers) + bool Prepare(TfLiteContext* context, TfLiteNode* node) {} + // Actual running of the delegate subgraph. + bool Invoke(TfLiteContext* context, TfLiteNode* node) {} + // ... Add any other methods needed. +}; + +// Create the TfLiteRegistration for the Kernel node which will replace +// the subgraph in the main TfLite graph. +TfLiteRegistration GetMyDelegateNodeRegistration() { + // This is the registration for the Delegate Node that gets added to + // the TFLite graph instead of the subgraph it replaces. + // It is treated as an OP node. But in this case + // Init initializes the delegate. + // Invoke runs the delegate graph. + // Prepare prepares the delegate. + // Free performs any memory cleanup needed by the delegate. + TfLiteRegistration kernel_registration; + kernel_registration.builtin_code = kTfLiteBuiltinDelegate; + kernel_registration.custom_name = "MyDelegate"; + kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void { + delete reinterpret_cast(buffer); + }; + kernel_registration.init = [](TfLiteContext* context, const char* buffer, + size_t) -> void* { + // In the node init phase, initialize MyDelegate instance + const TfLiteDelegateParams* delegate_params = + reinterpret_cast(buffer); + MyDelegate* my_delegate = new MyDelegate; + if (!my_delegate->Init(context, params)) { + return nullptr; + } + return my_delegate; + }; + kernel_registration.invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + MyDelegate* kernel = reinterpret_cast(node->user_data); + return kernel->Invoke(context, node); + }; + kernel_registration.prepare = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + MyDelegate* kernel = reinterpret_cast(node->user_data); + return kernel->Prepare(context, node); + }; + + return kernel_registration; +} + +// TfLiteDelegate methods + +TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { + // Claim all nodes that can be evaluated by the delegate and ask the + // framework to update the graph with delegate kernel instead. + std::vector supported_nodes; + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + TfLiteNode* node; + TfLiteRegistration* registration; + for (int node_index : TfLiteIntArrayView(plan)) { + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + if (MyDelegate::SupportedOp(registration)) { + supported_nodes.push_back(node_index); + } + } + TfLiteRegistration my_delegate_kernel_registration = + GetMyDelegateNodeRegistration(); + + // This call split the graphs into subgraphs, for subgraphs that can be + // handled by the delegate, it will replace it with a + // 'my_delegate_kernel_registration' + TfLiteIntArray* supported_nodes_int_array = + ::tflite::ConvertVectorToTfLiteIntArray(supported_nodes); + auto status = context->ReplaceNodeSubsetsWithDelegateKernels( + context, my_delegate_kernel_registration, + supported_nodes_int_array, delegate); + TfLiteIntArrayFree(supported_nodes_int_array); + return status +} + +void FreeBufferHandle(TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle* handle) { + // Do any cleanups. +} + +TfLiteStatus CopyToBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* tensor) { + // Copies data from tensor to delegate buffer if needed. + return kTfLiteOk; +} + +TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* tensor) { + // Copies the data from delegate buffer into the tensor raw memory. + return kTfLiteOk; +} + +// Caller takes ownership of the returned pointer. +TfLiteDelegate* CreateMyDelegate() { + TfLiteDelegate* delegate = new TfLiteDelegate; + + delegate->data_ = nullptr; + delegate->flags = kTfLiteDelegateFlagsNone; + delegate->Prepare = &DelegatePrepare; + // This cannot be null. + delegate->CopyFromBufferHandle = &CopyFromBufferHandle; + // This can be null. + delegate->CopyToBufferHandle = &CopyToBufferHandle; + // This can be null. + delegate->FreeBufferHandle = &FreeBufferHandle; + + return delegate; +} + + +// To add the delegate you need to call + +auto* my_delegate = CreateMyDelegate(); +if (interpreter->ModifyGraphWithDelegate(my_delegate) != + kTfLiteOk) { + // Handle error +} else { + interpreter->Invoke(); +} +... +// Don't forget to delete your delegate +delete my_delegate; +``` diff --git a/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb b/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb index 43168e394f5..ba6d266361b 100644 --- a/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb +++ b/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb @@ -518,7 +518,7 @@ "source": [ "## Choose a `model_spec` that Represents a Model for Text Classifier\n", "\n", - "Each `model_spec` object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf), averaging word embeddings and [BERT-Base]((https://arxiv.org/pdf/1810.04805.pdf) models.\n", + "Each `model_spec` object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf), averaging word embeddings and [BERT-Base](https://arxiv.org/pdf/1810.04805.pdf) models.\n", "\n", "Supported Model | Name of model_spec | Model Description\n", "--- | --- | ---\n", @@ -548,7 +548,7 @@ "source": [ "## Load Input Data Specific to an On-device ML App\n", "\n", - "The [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) is one of the tasks in the [GLUE](https://gluebenchmark.com/) benchmark . It contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews.\n", + "The [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) is one of the tasks in the [GLUE](https://gluebenchmark.com/) benchmark. It contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews.\n", "\n", "Download the archived version of the dataset and extract it.\n" ] @@ -669,9 +669,7 @@ "source": [ "## Evaluate the Customized Model\n", "\n", - "Evaluate the result of the model and get the loss and accuracy of the model.\n", - "\n", - "Evaluate the loss and accuracy in the test data." + "Evaluate the model with the test data and get its loss and accuracy." ] }, { @@ -749,7 +747,7 @@ "id": "HZKYthlVrTos" }, "source": [ - "You can evalute the tflite model with `evaluate_tflite` method." + "You can evalute the tflite model with `evaluate_tflite` method to get its accuracy." ] }, { @@ -760,7 +758,7 @@ }, "outputs": [], "source": [ - "model.evaluate_tflite('average_word_vec/model.tflite', test_data)" + "accuracy = model.evaluate_tflite('average_word_vec/model.tflite', test_data)" ] }, { @@ -771,7 +769,7 @@ "source": [ "## Advanced Usage\n", "\n", - "The `create` function is the driver function that the Model Maker library uses to create models. The `model spec` parameter defines the model specification. The `AverageWordVecModelSpec` and `BertClassifierModelSpec` classes are currently supported. The `create` function comprises of the following steps:\n", + "The `create` function is the driver function that the Model Maker library uses to create models. The `model_spec` parameter defines the model specification. The `AverageWordVecModelSpec` and `BertClassifierModelSpec` classes are currently supported. The `create` function comprises of the following steps:\n", "\n", "1. Creates the model for the text classifier according to `model_spec`.\n", "2. Trains the classifier model. The default epochs and the default batch size are set by the `default_training_epochs` and `default_batch_size` variables in the `model_spec` object.\n", @@ -867,7 +865,7 @@ "The model parameters you can adjust are:\n", "\n", "* `seq_len`: Length of the sequence to feed into the model.\n", - "* `initializer_range`: The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n", + "* `initializer_range`: The standard deviation of the `truncated_normal_initializer` for initializing all weight matrices.\n", "* `trainable`: Boolean that specifies whether the pre-trained layer is trainable.\n", "\n", "The training pipeline parameters you can adjust are:\n", diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index 6a77c5a5f11..ef5831fda50 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -412,7 +412,11 @@ class Interpreter { /// 2. kTfLiteDelegateError: Delegation failed due to an error in the /// delegate. The Interpreter has been restored to its pre-delegation state. /// NOTE: This undoes all delegates previously applied to the Interpreter. - /// 3. kTfLiteError: Unexpected/runtime failure. + /// 3. kTfLiteApplicationError : Delegation failed to be applied due to the + /// incompatibility with the TfLite runtime, e.g., the model graph is already + /// immutable when applying the delegate. However, the interpreter could still + /// be invoked. + /// 4. kTfLiteError: Unexpected/runtime failure. /// WARNING: This is an experimental API and subject to change. TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc index 66728ea89e9..b70908e7162 100644 --- a/tensorflow/lite/interpreter_test.cc +++ b/tensorflow/lite/interpreter_test.cc @@ -49,6 +49,13 @@ class InterpreterTest : public ::testing::Test { protected: TfLiteContext* GetInterpreterContext() { return interpreter_.context_; } + std::vector* + mutable_lazy_delegate_providers() { + return &interpreter_.lazy_delegate_providers_; + } + + bool HasDelegates() { return interpreter_.HasDelegates(); } + Interpreter interpreter_; }; @@ -1782,6 +1789,63 @@ TEST_F(TestCustomAllocation, ResizeTensorsWithEnoughMemory) { VerifyInvoke(); } +// Tests related to lazy delegate providers that are primarily used for applying +// TfLite delegates by default. +class TestLazyDelegateProvider : public InterpreterTest { + protected: + struct DummyLazyDelegateProvider : public TfLiteDelegate { + explicit DummyLazyDelegateProvider(int64_t support_flags) { + data_ = static_cast(this); + flags = support_flags; + Prepare = [](TfLiteContext*, TfLiteDelegate* delegate) -> TfLiteStatus { + return kTfLiteOk; + }; + } + }; + + void InitWithLazyDelegate(int64_t delegate_flags, + bool create_dyanmic_tensor = false) { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_EQ(interpreter_.AddTensors(2), kTfLiteOk); + interpreter_.SetInputs({0}); + interpreter_.SetOutputs({1}); + interpreter_.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®); + + Interpreter::TfLiteDelegatePtr delegate( + new DummyLazyDelegateProvider(delegate_flags), + [](TfLiteDelegate* delegate) { + auto* dummy = + static_cast(delegate->data_); + delete dummy; + }); + mutable_lazy_delegate_providers()->push_back(std::move(delegate)); + + if (create_dyanmic_tensor) { + // Mark the output as dynamic tensor. + interpreter_.tensor(1)->data.raw = nullptr; + interpreter_.tensor(1)->allocation_type = kTfLiteDynamic; + } + } +}; + +TEST_F(TestLazyDelegateProvider, ApplicationSuccess) { + InitWithLazyDelegate(kTfLiteDelegateFlagsNone); + EXPECT_EQ(kTfLiteOk, interpreter_.AllocateTensors()); + // We clear Interpreter::lazy_delegate_providers_ after they are tried out. + EXPECT_TRUE(mutable_lazy_delegate_providers()->empty()); + EXPECT_TRUE(HasDelegates()); +} + +TEST_F(TestLazyDelegateProvider, ApplicationSkipped) { + InitWithLazyDelegate(kTfLiteDelegateFlagsNone, + true /* create_dyanmic_tensor */); + EXPECT_EQ(kTfLiteOk, interpreter_.AllocateTensors()); + EXPECT_TRUE(mutable_lazy_delegate_providers()->empty()); + // As the delegate doesn't allow dynamic tensor, the delegate won't be applied + // and the interpreter doesn't have any delegate applied. + EXPECT_FALSE(HasDelegates()); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index f7aa91dc24d..f3a639ad0c6 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -549,6 +549,7 @@ BUILTIN_KERNEL_SRCS = [ "comparisons.cc", "concatenation.cc", "conv.cc", + "cumsum.cc", "densify.cc", "depth_to_space.cc", "depthwise_conv.cc", @@ -717,7 +718,6 @@ cc_library( name = "custom_ops", srcs = [ "complex_support.cc", - "cumsum.cc", "multinomial.cc", "random_standard_normal.cc", "rfft2d.cc", @@ -2341,16 +2341,15 @@ cc_test( cc_test( name = "cumsum_test", + size = "small", srcs = ["cumsum_test.cc"], deps = [ - ":custom_ops", ":test_main", ":test_util", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", - "@flatbuffers", ], ) diff --git a/tensorflow/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc index 6e0316538b9..20c1876a96d 100644 --- a/tensorflow/lite/kernels/activations_test.cc +++ b/tensorflow/lite/kernels/activations_test.cc @@ -836,58 +836,102 @@ TEST_P(TanhOpTest, TanhInt16) { const float kMax = 32767.f / 32768.f; QuantizedActivationsOpModel m( GetRegistration(), BuiltinOperator_TANH, - /*input=*/{TensorType_INT16, {89}, 8 * kMin, 8 * kMax}, - /*output=*/{TensorType_INT16, {89}, kMin, kMax}); + /*input=*/{TensorType_INT16, {177}, 16 * kMin, 16 * kMax}, + /*output=*/{TensorType_INT16, {177}, kMin, kMax}); m.SetInput( - {-8.0000000000, -7.8181818182, -7.6363636364, -7.4545454545, - -7.2727272727, -7.0909090909, -6.9090909091, -6.7272727273, - -6.5454545455, -6.3636363636, -6.1818181818, -6.0000000000, - -5.8181818182, -5.6363636364, -5.4545454545, -5.2727272727, - -5.0909090909, -4.9090909091, -4.7272727273, -4.5454545455, - -4.3636363636, -4.1818181818, -4.0000000000, -3.8181818182, - -3.6363636364, -3.4545454545, -3.2727272727, -3.0909090909, - -2.9090909091, -2.7272727273, -2.5454545455, -2.3636363636, - -2.1818181818, -2.0000000000, -1.8181818182, -1.6363636364, - -1.4545454545, -1.2727272727, -1.0909090909, -0.9090909091, - -0.7272727273, -0.5454545455, -0.3636363636, -0.1818181818, - 0.0000000000, 0.1818181818, 0.3636363636, 0.5454545455, - 0.7272727273, 0.9090909091, 1.0909090909, 1.2727272727, - 1.4545454545, 1.6363636364, 1.8181818182, 2.0000000000, - 2.1818181818, 2.3636363636, 2.5454545455, 2.7272727273, - 2.9090909091, 3.0909090909, 3.2727272727, 3.4545454545, - 3.6363636364, 3.8181818182, 4.0000000000, 4.1818181818, - 4.3636363636, 4.5454545455, 4.7272727273, 4.9090909091, - 5.0909090909, 5.2727272727, 5.4545454545, 5.6363636364, - 5.8181818182, 6.0000000000, 6.1818181818, 6.3636363636, - 6.5454545455, 6.7272727273, 6.9090909091, 7.0909090909, - 7.2727272727, 7.4545454545, 7.6363636364, 7.8181818182, - 8.0000000000}); + {-20.0000000000, -19.7727272727, -19.5454545455, -19.3181818182, + -19.0909090909, -18.8636363636, -18.6363636364, -18.4090909091, + -18.1818181818, -17.9545454545, -17.7272727273, -17.5000000000, + -17.2727272727, -17.0454545455, -16.8181818182, -16.5909090909, + -16.3636363636, -16.1363636364, -15.9090909091, -15.6818181818, + -15.4545454545, -15.2272727273, -15.0000000000, -14.7727272727, + -14.5454545455, -14.3181818182, -14.0909090909, -13.8636363636, + -13.6363636364, -13.4090909091, -13.1818181818, -12.9545454545, + -12.7272727273, -12.5000000000, -12.2727272727, -12.0454545455, + -11.8181818182, -11.5909090909, -11.3636363636, -11.1363636364, + -10.9090909091, -10.6818181818, -10.4545454545, -10.2272727273, + -10.0000000000, -9.7727272727, -9.5454545455, -9.3181818182, + -9.0909090909, -8.8636363636, -8.6363636364, -8.4090909091, + -8.1818181818, -7.9545454545, -7.7272727273, -7.5000000000, + -7.2727272727, -7.0454545455, -6.8181818182, -6.5909090909, + -6.3636363636, -6.1363636364, -5.9090909091, -5.6818181818, + -5.4545454545, -5.2272727273, -5.0000000000, -4.7727272727, + -4.5454545455, -4.3181818182, -4.0909090909, -3.8636363636, + -3.6363636364, -3.4090909091, -3.1818181818, -2.9545454545, + -2.7272727273, -2.5000000000, -2.2727272727, -2.0454545455, + -1.8181818182, -1.5909090909, -1.3636363636, -1.1363636364, + -0.9090909091, -0.6818181818, -0.4545454545, -0.2272727273, + 0.0000000000, 0.2272727273, 0.4545454545, 0.6818181818, + 0.9090909091, 1.1363636364, 1.3636363636, 1.5909090909, + 1.8181818182, 2.0454545455, 2.2727272727, 2.5000000000, + 2.7272727273, 2.9545454545, 3.1818181818, 3.4090909091, + 3.6363636364, 3.8636363636, 4.0909090909, 4.3181818182, + 4.5454545455, 4.7727272727, 5.0000000000, 5.2272727273, + 5.4545454545, 5.6818181818, 5.9090909091, 6.1363636364, + 6.3636363636, 6.5909090909, 6.8181818182, 7.0454545455, + 7.2727272727, 7.5000000000, 7.7272727273, 7.9545454545, + 8.1818181818, 8.4090909091, 8.6363636364, 8.8636363636, + 9.0909090909, 9.3181818182, 9.5454545455, 9.7727272727, + 10.0000000000, 10.2272727273, 10.4545454545, 10.6818181818, + 10.9090909091, 11.1363636364, 11.3636363636, 11.5909090909, + 11.8181818182, 12.0454545455, 12.2727272727, 12.5000000000, + 12.7272727273, 12.9545454545, 13.1818181818, 13.4090909091, + 13.6363636364, 13.8636363636, 14.0909090909, 14.3181818182, + 14.5454545455, 14.7727272727, 15.0000000000, 15.2272727273, + 15.4545454545, 15.6818181818, 15.9090909091, 16.1363636364, + 16.3636363636, 16.5909090909, 16.8181818182, 17.0454545455, + 17.2727272727, 17.5000000000, 17.7272727273, 17.9545454545, + 18.1818181818, 18.4090909091, 18.6363636364, 18.8636363636, + 19.0909090909, 19.3181818182, 19.5454545455, 19.7727272727, + 20.0000000000}); m.Invoke(); EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - {-0.9999997749, -0.9999996762, -0.9999995342, -0.9999993300, - -0.9999990361, -0.9999986134, -0.9999980053, -0.9999971306, - -0.9999958722, -0.9999940619, -0.9999914578, -0.9999877117, - -0.9999823226, -0.9999745703, -0.9999634183, -0.9999473758, - -0.9999242982, -0.9998911009, -0.9998433469, -0.9997746542, - -0.9996758446, -0.9995337191, -0.9993292997, -0.9990353053, - -0.9986125310, -0.9980046622, -0.9971308601, -0.9958751909, - -0.9940716137, -0.9914827859, -0.9877703933, -0.9824541388, - -0.9748561217, -0.9640275801, -0.9486568273, -0.9269625051, - -0.8965880154, -0.8545351057, -0.7972097087, -0.7206956332, - -0.6213939966, -0.4971057414, -0.3484130125, -0.1798408185, - 0.0000000000, 0.1798408185, 0.3484130125, 0.4971057414, - 0.6213939966, 0.7206956332, 0.7972097087, 0.8545351057, - 0.8965880154, 0.9269625051, 0.9486568273, 0.9640275801, - 0.9748561217, 0.9824541388, 0.9877703933, 0.9914827859, - 0.9940716137, 0.9958751909, 0.9971308601, 0.9980046622, - 0.9986125310, 0.9990353053, 0.9993292997, 0.9995337191, - 0.9996758446, 0.9997746542, 0.9998433469, 0.9998911009, - 0.9999242982, 0.9999473758, 0.9999634183, 0.9999745703, - 0.9999823226, 0.9999877117, 0.9999914578, 0.9999940619, - 0.9999958722, 0.9999971306, 0.9999980053, 0.9999986134, - 0.9999990361, 0.9999993300, 0.9999995342, 0.9999996762, - 0.9999997749}, + {-1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000, + -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000, + -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000, + -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000, + -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000, + -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000, + -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000, + -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000, + -1.0000000000, -1.0000000000, -1.0000000000, -0.9999999999, + -0.9999999999, -0.9999999998, -0.9999999997, -0.9999999996, + -0.9999999993, -0.9999999989, -0.9999999983, -0.9999999974, + -0.9999999959, -0.9999999935, -0.9999999898, -0.9999999839, + -0.9999999746, -0.9999999600, -0.9999999370, -0.9999999007, + -0.9999998435, -0.9999997535, -0.9999996117, -0.9999993882, + -0.9999990361, -0.9999984815, -0.9999976076, -0.9999962309, + -0.9999940619, -0.9999906449, -0.9999852614, -0.9999767801, + -0.9999634183, -0.9999423677, -0.9999092043, -0.9998569589, + -0.9997746542, -0.9996450004, -0.9994407705, -0.9991190997, + -0.9986125310, -0.9978149744, -0.9965597488, -0.9945853915, + -0.9914827859, -0.9866142982, -0.9789923110, -0.9671021386, + -0.9486568273, -0.9202886021, -0.8772337852, -0.8131859906, + -0.7206956332, -0.5927001330, -0.4256281972, -0.2234388228, + 0.0000000000, 0.2234388228, 0.4256281972, 0.5927001330, + 0.7206956332, 0.8131859906, 0.8772337852, 0.9202886021, + 0.9486568273, 0.9671021386, 0.9789923110, 0.9866142982, + 0.9914827859, 0.9945853915, 0.9965597488, 0.9978149744, + 0.9986125310, 0.9991190997, 0.9994407705, 0.9996450004, + 0.9997746542, 0.9998569589, 0.9999092043, 0.9999423677, + 0.9999634183, 0.9999767801, 0.9999852614, 0.9999906449, + 0.9999940619, 0.9999962309, 0.9999976076, 0.9999984815, + 0.9999990361, 0.9999993882, 0.9999996117, 0.9999997535, + 0.9999998435, 0.9999999007, 0.9999999370, 0.9999999600, + 0.9999999746, 0.9999999839, 0.9999999898, 0.9999999935, + 0.9999999959, 0.9999999974, 0.9999999983, 0.9999999989, + 0.9999999993, 0.9999999996, 0.9999999997, 0.9999999998, + 0.9999999999, 0.9999999999, 1.0000000000, 1.0000000000, + 1.0000000000, 1.0000000000, 1.0000000000, 1.0000000000, + 1.0000000000, 1.0000000000, 1.0000000000, 1.0000000000, + 1.0000000000, 1.0000000000, 1.0000000000, 1.0000000000, + 1.0000000000, 1.0000000000, 1.0000000000, 1.0000000000, + 1.0000000000, 1.0000000000, 1.0000000000, 1.0000000000, + 1.0000000000, 1.0000000000, 1.0000000000, 1.0000000000, + 1.0000000000, 1.0000000000, 1.0000000000, 1.0000000000, + 1.0000000000, 1.0000000000, 1.0000000000, 1.0000000000, + 1.0000000000}, kQuantizedToleranceInt16))); } @@ -1031,54 +1075,94 @@ TEST_P(LogisticOpTest, SigmoidInt16) { const float kMax = 32767.f / 32768.f; QuantizedActivationsOpModel m( GetRegistration(), BuiltinOperator_LOGISTIC, - /*input=*/{TensorType_INT16, {89}, 8 * kMin, 8 * kMax}, - /*output=*/{TensorType_INT16, {89}, kMin, kMax}); + /*input=*/{TensorType_INT16, {177}, 16 * kMin, 16 * kMax}, + /*output=*/{TensorType_INT16, {177}, kMin, kMax}); m.SetInput( - {-10.0000000000, -9.7727272727, -9.5454545455, -9.3181818182, - -9.0909090909, -8.8636363636, -8.6363636364, -8.4090909091, - -8.1818181818, -7.9545454545, -7.7272727273, -7.5000000000, - -7.2727272727, -7.0454545455, -6.8181818182, -6.5909090909, - -6.3636363636, -6.1363636364, -5.9090909091, -5.6818181818, - -5.4545454545, -5.2272727273, -5.0000000000, -4.7727272727, - -4.5454545455, -4.3181818182, -4.0909090909, -3.8636363636, - -3.6363636364, -3.4090909091, -3.1818181818, -2.9545454545, - -2.7272727273, -2.5000000000, -2.2727272727, -2.0454545455, - -1.8181818182, -1.5909090909, -1.3636363636, -1.1363636364, - -0.9090909091, -0.6818181818, -0.4545454545, -0.2272727273, - 0.0000000000, 0.2272727273, 0.4545454545, 0.6818181818, - 0.9090909091, 1.1363636364, 1.3636363636, 1.5909090909, - 1.8181818182, 2.0454545455, 2.2727272727, 2.5000000000, - 2.7272727273, 2.9545454545, 3.1818181818, 3.4090909091, - 3.6363636364, 3.8636363636, 4.0909090909, 4.3181818182, - 4.5454545455, 4.7727272727, 5.0000000000, 5.2272727273, - 5.4545454545, 5.6818181818, 5.9090909091, 6.1363636364, - 6.3636363636, 6.5909090909, 6.8181818182, 7.0454545455, - 7.2727272727, 7.5000000000, 7.7272727273, 7.9545454545, - 8.1818181818, 8.4090909091, 8.6363636364, 8.8636363636, - 9.0909090909, 9.3181818182, 9.5454545455, 9.7727272727, - 10.0000000000}); + {-20.0000000000, -19.7727272727, -19.5454545455, -19.3181818182, + -19.0909090909, -18.8636363636, -18.6363636364, -18.4090909091, + -18.1818181818, -17.9545454545, -17.7272727273, -17.5000000000, + -17.2727272727, -17.0454545455, -16.8181818182, -16.5909090909, + -16.3636363636, -16.1363636364, -15.9090909091, -15.6818181818, + -15.4545454545, -15.2272727273, -15.0000000000, -14.7727272727, + -14.5454545455, -14.3181818182, -14.0909090909, -13.8636363636, + -13.6363636364, -13.4090909091, -13.1818181818, -12.9545454545, + -12.7272727273, -12.5000000000, -12.2727272727, -12.0454545455, + -11.8181818182, -11.5909090909, -11.3636363636, -11.1363636364, + -10.9090909091, -10.6818181818, -10.4545454545, -10.2272727273, + -10.0000000000, -9.7727272727, -9.5454545455, -9.3181818182, + -9.0909090909, -8.8636363636, -8.6363636364, -8.4090909091, + -8.1818181818, -7.9545454545, -7.7272727273, -7.5000000000, + -7.2727272727, -7.0454545455, -6.8181818182, -6.5909090909, + -6.3636363636, -6.1363636364, -5.9090909091, -5.6818181818, + -5.4545454545, -5.2272727273, -5.0000000000, -4.7727272727, + -4.5454545455, -4.3181818182, -4.0909090909, -3.8636363636, + -3.6363636364, -3.4090909091, -3.1818181818, -2.9545454545, + -2.7272727273, -2.5000000000, -2.2727272727, -2.0454545455, + -1.8181818182, -1.5909090909, -1.3636363636, -1.1363636364, + -0.9090909091, -0.6818181818, -0.4545454545, -0.2272727273, + 0.0000000000, 0.2272727273, 0.4545454545, 0.6818181818, + 0.9090909091, 1.1363636364, 1.3636363636, 1.5909090909, + 1.8181818182, 2.0454545455, 2.2727272727, 2.5000000000, + 2.7272727273, 2.9545454545, 3.1818181818, 3.4090909091, + 3.6363636364, 3.8636363636, 4.0909090909, 4.3181818182, + 4.5454545455, 4.7727272727, 5.0000000000, 5.2272727273, + 5.4545454545, 5.6818181818, 5.9090909091, 6.1363636364, + 6.3636363636, 6.5909090909, 6.8181818182, 7.0454545455, + 7.2727272727, 7.5000000000, 7.7272727273, 7.9545454545, + 8.1818181818, 8.4090909091, 8.6363636364, 8.8636363636, + 9.0909090909, 9.3181818182, 9.5454545455, 9.7727272727, + 10.0000000000, 10.2272727273, 10.4545454545, 10.6818181818, + 10.9090909091, 11.1363636364, 11.3636363636, 11.5909090909, + 11.8181818182, 12.0454545455, 12.2727272727, 12.5000000000, + 12.7272727273, 12.9545454545, 13.1818181818, 13.4090909091, + 13.6363636364, 13.8636363636, 14.0909090909, 14.3181818182, + 14.5454545455, 14.7727272727, 15.0000000000, 15.2272727273, + 15.4545454545, 15.6818181818, 15.9090909091, 16.1363636364, + 16.3636363636, 16.5909090909, 16.8181818182, 17.0454545455, + 17.2727272727, 17.5000000000, 17.7272727273, 17.9545454545, + 18.1818181818, 18.4090909091, 18.6363636364, 18.8636363636, + 19.0909090909, 19.3181818182, 19.5454545455, 19.7727272727, + 20.0000000000}); m.Invoke(); EXPECT_THAT( m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - {0.0000453979, 0.0000569815, 0.0000715205, 0.0000897689, 0.0001126729, - 0.0001414198, 0.0001774998, 0.0002227827, 0.0002796147, 0.0003509396, - 0.0004404502, 0.0005527786, 0.0006937345, 0.0008706021, 0.0010925128, - 0.0013709094, 0.0017201256, 0.0021581065, 0.0027073042, 0.0033957870, - 0.0042586071, 0.0053394826, 0.0066928509, 0.0083863576, 0.0105038445, - 0.0131488902, 0.0164489307, 0.0205599431, 0.0256715863, 0.0320125562, - 0.0398556989, 0.0495221198, 0.0613831074, 0.0758581800, 0.0934070047, - 0.1145124805, 0.1396521834, 0.1692560327, 0.2036499335, 0.2429886272, - 0.2871859014, 0.3358556241, 0.3882805886, 0.4434251301, 0.5000000000, - 0.5565748699, 0.6117194114, 0.6641443759, 0.7128140986, 0.7570113728, - 0.7963500665, 0.8307439673, 0.8603478166, 0.8854875195, 0.9065929953, - 0.9241418200, 0.9386168926, 0.9504778802, 0.9601443011, 0.9679874438, - 0.9743284137, 0.9794400569, 0.9835510693, 0.9868511098, 0.9894961555, - 0.9916136424, 0.9933071491, 0.9946605174, 0.9957413929, 0.9966042130, - 0.9972926958, 0.9978418935, 0.9982798744, 0.9986290906, 0.9989074872, - 0.9991293979, 0.9993062655, 0.9994472214, 0.9995595498, 0.9996490604, - 0.9997203853, 0.9997772173, 0.9998225002, 0.9998585802, 0.9998873271, - 0.9999102311, 0.9999284795, 0.9999430185, 0.9999546021}, + {0.0000000021, 0.0000000026, 0.0000000032, 0.0000000041, 0.0000000051, + 0.0000000064, 0.0000000081, 0.0000000101, 0.0000000127, 0.0000000159, + 0.0000000200, 0.0000000251, 0.0000000315, 0.0000000396, 0.0000000497, + 0.0000000623, 0.0000000782, 0.0000000982, 0.0000001232, 0.0000001547, + 0.0000001942, 0.0000002437, 0.0000003059, 0.0000003840, 0.0000004819, + 0.0000006049, 0.0000007593, 0.0000009530, 0.0000011962, 0.0000015014, + 0.0000018846, 0.0000023654, 0.0000029690, 0.0000037266, 0.0000046776, + 0.0000058711, 0.0000073693, 0.0000092497, 0.0000116100, 0.0000145724, + 0.0000182909, 0.0000229581, 0.0000288162, 0.0000361690, 0.0000453979, + 0.0000569815, 0.0000715205, 0.0000897689, 0.0001126729, 0.0001414198, + 0.0001774998, 0.0002227827, 0.0002796147, 0.0003509396, 0.0004404502, + 0.0005527786, 0.0006937345, 0.0008706021, 0.0010925128, 0.0013709094, + 0.0017201256, 0.0021581065, 0.0027073042, 0.0033957870, 0.0042586071, + 0.0053394826, 0.0066928509, 0.0083863576, 0.0105038445, 0.0131488902, + 0.0164489307, 0.0205599431, 0.0256715863, 0.0320125562, 0.0398556989, + 0.0495221198, 0.0613831074, 0.0758581800, 0.0934070047, 0.1145124805, + 0.1396521834, 0.1692560327, 0.2036499335, 0.2429886272, 0.2871859014, + 0.3358556241, 0.3882805886, 0.4434251301, 0.5000000000, 0.5565748699, + 0.6117194114, 0.6641443759, 0.7128140986, 0.7570113728, 0.7963500665, + 0.8307439673, 0.8603478166, 0.8854875195, 0.9065929953, 0.9241418200, + 0.9386168926, 0.9504778802, 0.9601443011, 0.9679874438, 0.9743284137, + 0.9794400569, 0.9835510693, 0.9868511098, 0.9894961555, 0.9916136424, + 0.9933071491, 0.9946605174, 0.9957413929, 0.9966042130, 0.9972926958, + 0.9978418935, 0.9982798744, 0.9986290906, 0.9989074872, 0.9991293979, + 0.9993062655, 0.9994472214, 0.9995595498, 0.9996490604, 0.9997203853, + 0.9997772173, 0.9998225002, 0.9998585802, 0.9998873271, 0.9999102311, + 0.9999284795, 0.9999430185, 0.9999546021, 0.9999638310, 0.9999711838, + 0.9999770419, 0.9999817091, 0.9999854276, 0.9999883900, 0.9999907503, + 0.9999926307, 0.9999941289, 0.9999953224, 0.9999962734, 0.9999970310, + 0.9999976346, 0.9999981154, 0.9999984986, 0.9999988038, 0.9999990470, + 0.9999992407, 0.9999993951, 0.9999995181, 0.9999996160, 0.9999996941, + 0.9999997563, 0.9999998058, 0.9999998453, 0.9999998768, 0.9999999018, + 0.9999999218, 0.9999999377, 0.9999999503, 0.9999999604, 0.9999999685, + 0.9999999749, 0.9999999800, 0.9999999841, 0.9999999873, 0.9999999899, + 0.9999999919, 0.9999999936, 0.9999999949, 0.9999999959, 0.9999999968, + 0.9999999974, 0.9999999979}, kQuantizedToleranceInt16))); } diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h index 1c73f06487b..b6e73c2d7a1 100644 --- a/tensorflow/lite/kernels/builtin_op_kernels.h +++ b/tensorflow/lite/kernels/builtin_op_kernels.h @@ -44,6 +44,7 @@ TfLiteRegistration* Register_CEIL(); TfLiteRegistration* Register_CONCATENATION(); TfLiteRegistration* Register_CONV_2D(); TfLiteRegistration* Register_COS(); +TfLiteRegistration* Register_CUMSUM(); TfLiteRegistration* Register_DENSIFY(); TfLiteRegistration* Register_DEPTH_TO_SPACE(); TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); diff --git a/tensorflow/lite/kernels/cumsum.cc b/tensorflow/lite/kernels/cumsum.cc index 173de0959fa..b37bab15803 100644 --- a/tensorflow/lite/kernels/cumsum.cc +++ b/tensorflow/lite/kernels/cumsum.cc @@ -13,44 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -// TODO(b/161933288): Promote this op to builtin-op when we can add new builtin -// ops. namespace tflite { namespace ops { -namespace custom { +namespace builtin { namespace cumsum { -typedef struct { - bool exclusive; - bool reverse; -} TfLiteCumsumParams; - static const int kInputTensor = 0; static const int kAxisTensor = 1; static const int kOutputTensor = 0; -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* data = new TfLiteCumsumParams; - const uint8_t* buffer_data = reinterpret_cast(buffer); - - const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_data, length).AsMap(); - data->exclusive = m["exclusive"].AsBool(); - data->reverse = m["reverse"].AsBool(); - return data; -} - -void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); -} - TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -58,8 +37,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* axis = GetInput(context, node, kAxisTensor); - TF_LITE_ENSURE(context, - input->type == kTfLiteInt32 || input->type == kTfLiteFloat32); + TF_LITE_ENSURE(context, input->type == kTfLiteInt32 || + input->type == kTfLiteFloat32 || + input->type == kTfLiteInt64); TF_LITE_ENSURE_EQ(context, axis->type, kTfLiteInt32); TF_LITE_ENSURE_EQ(context, NumElements(axis), 1); @@ -78,7 +58,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - auto* params = reinterpret_cast(node->user_data); + auto* params = reinterpret_cast(node->builtin_data); int axis = *GetTensorData(axis_tensor); if (axis < 0) axis += NumDimensions(input); @@ -95,6 +75,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorData(output)); break; } + case kTfLiteInt64: { + optimized_ops::CumSum(GetTensorData(input), + GetTensorShape(input), axis, params->exclusive, + params->reverse, GetTensorData(output)); + break; + } case kTfLiteFloat32: { optimized_ops::CumSum(GetTensorData(input), GetTensorShape(input), axis, params->exclusive, params->reverse, @@ -115,11 +101,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace cumsum TfLiteRegistration* Register_CUMSUM() { - static TfLiteRegistration r = {cumsum::Init, cumsum::Free, cumsum::Prepare, + static TfLiteRegistration r = {nullptr, nullptr, cumsum::Prepare, cumsum::Eval}; return &r; } -} // namespace custom +} // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/lite/kernels/cumsum_test.cc b/tensorflow/lite/kernels/cumsum_test.cc index 092defdcba3..f11781dc744 100644 --- a/tensorflow/lite/kernels/cumsum_test.cc +++ b/tensorflow/lite/kernels/cumsum_test.cc @@ -17,18 +17,14 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/custom_ops_register.h" #include "tensorflow/lite/kernels/test_util.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/testing/util.h" namespace tflite { namespace ops { -namespace custom { - -TfLiteRegistration* Register_CUMSUM(); +namespace builtin { namespace { @@ -42,13 +38,8 @@ class CumsumOpModel : public SingleOpModel { output_ = AddOutput(output); - flexbuffers::Builder fbb; - fbb.Map([&]() { - fbb.Bool("exclusive", exclusive); - fbb.Bool("reverse", reverse); - }); - fbb.Finish(); - SetCustomOp("Cumsum", fbb.GetBuffer(), Register_CUMSUM); + SetBuiltinOp(BuiltinOperator_CUMSUM, BuiltinOptions_CumsumOptions, + CreateCumsumOptions(builder_, exclusive, reverse).Union()); BuildInterpreter({GetShape(input_), GetShape(axis_)}); } @@ -77,6 +68,23 @@ TEST(CumsumOpTest, SimpleIntTest) { testing::ElementsAreArray({1, 3, 6, 10, 5, 11, 18, 26})); } +TEST(CumsumOpTest, SimpleInt64Test) { + CumsumOpModel m({TensorType_INT64, {2, 4}}, {TensorType_INT64, {}}, + false, false); + + m.PopulateTensor( + m.input(), {100000000001l, 100000000002l, 100000000003l, 100000000004l, + 100000000005l, 100000000006l, 100000000007l, 100000000008l}); + m.PopulateTensor(m.axis(), {1}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray( + {100000000001l, 200000000003l, 300000000006l, + 400000000010l, 100000000005l, 200000000011l, + 300000000018l, 400000000026l})); +} + TEST(CumsumOpTest, SimpleIntAxis0Test) { CumsumOpModel m({TensorType_INT32, {2, 4}}, {TensorType_INT32, {}}, false, false); @@ -143,6 +151,6 @@ TEST(CumsumOpTest, SimpleFloatTest) { } } // namespace -} // namespace custom +} // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/lite/kernels/custom_ops_register.h b/tensorflow/lite/kernels/custom_ops_register.h index 8aadd379a43..a24c062dbe6 100644 --- a/tensorflow/lite/kernels/custom_ops_register.h +++ b/tensorflow/lite/kernels/custom_ops_register.h @@ -21,7 +21,6 @@ namespace tflite { namespace ops { namespace custom { -TfLiteRegistration* Register_CUMSUM(); TfLiteRegistration* Register_HASHTABLE(); TfLiteRegistration* Register_HASHTABLE_FIND(); TfLiteRegistration* Register_HASHTABLE_IMPORT(); diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h index 51f3d2559db..f9472515417 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_hybrid_3x3_filter.h @@ -288,13 +288,13 @@ struct DepthwiseConvHybridWindowPerChannel> 9; - uint32_t ua = sigmoid_table_uint16[uh]; - uint32_t ub = sigmoid_table_uint16[uh + 1]; - uint32_t ut = abs_input_data & 0x1ff; + // Define uh as uint32_t type not to make this function overflow. + uint32_t uh = abs_input_data >> 9; + uint32_t result; - // Interpolation is done using the fractional bit. - uint32_t result = (ua << 9) + ut * (ub - ua); + if (uh >= 255) { + // Saturate to maximum. + result = 0x7FFF << 10; + } else { + uint32_t ua = sigmoid_table_uint16[uh]; + uint32_t ub = sigmoid_table_uint16[uh + 1]; + uint32_t ut = abs_input_data & 0x1ff; + // Interpolation is done using the fractional bit. + result = (ua << 9) + ut * (ub - ua); + } result = (input_data >= 0) ? (result + (1 << 9)) : ((1 << (16 + 9)) - result + (1 << 9) - 1); diff --git a/tensorflow/lite/kernels/internal/reference/reduce.h b/tensorflow/lite/kernels/internal/reference/reduce.h index d57b6f2c20e..a7c86ddbc71 100644 --- a/tensorflow/lite/kernels/internal/reference/reduce.h +++ b/tensorflow/lite/kernels/internal/reference/reduce.h @@ -381,8 +381,7 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point, const float scale = input_scale / output_scale; if (compute_sum) { // TODO(b/116341117): Eliminate float and do this completely in 8bit. - const float bias = - -input_zero_point * scale * num_elements_in_axis + 0.5f; + const float bias = -input_zero_point * scale * num_elements_in_axis; for (size_t idx = 0; idx < num_outputs; ++idx) { const U value = static_cast(TfLiteRound(temp_sum[idx] * scale + bias)) + @@ -390,7 +389,7 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point, output_data[idx] = static_cast(value); } } else { - const float bias = -input_zero_point * scale + 0.5f; + const float bias = -input_zero_point * scale; for (size_t idx = 0; idx < num_outputs; ++idx) { float float_mean = static_cast(temp_sum[idx]) / static_cast(num_elements_in_axis); diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index 3eb26565bc2..26bccd3a4b6 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -2102,10 +2102,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { forget_layer_norm_coefficients, cell_layer_norm_coefficients, output_layer_norm_coefficients, input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias, projection_weights, - projection_bias, params, &op_data->integer_lstm_param, - output_state, cell_state, output, scratch0, scratch1, scratch2, - scratch3, scratch4, scratch5, - CpuBackendContext::GetFromContext(context)); + projection_bias, params, /*forward_sequence=*/true, + /*time_major=*/true, &op_data->integer_lstm_param, output_state, + cell_state, output, scratch0, scratch1, scratch2, scratch3, + scratch4, scratch5, CpuBackendContext::GetFromContext(context)); } else { TfLiteTensor* scratch0; TF_LITE_ENSURE_OK(context, diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 695100fa92f..aa9db64f057 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -1412,8 +1412,10 @@ inline void LstmStepInteger8x8_16( TFLITE_DCHECK(input_to_input_effective_bias); TFLITE_DCHECK(recurrent_to_input_effective_bias); } - TFLITE_DCHECK(projection_effective_bias); - + const bool use_projection = (projection_weight_ptr != nullptr); + if (use_projection) { + TFLITE_DCHECK(projection_effective_bias); + } if (!use_cifg) { // Calculate the input gate. (If not CIFG.) CalculateLstmGateInteger8x8_16( @@ -1479,7 +1481,7 @@ inline void LstmStepInteger8x8_16( quantized_proj_clip, output_state_ptr, context, scratch0, scratch4, scratch5); // Copy output state to the output. Note that unlike float or hybrid, output - // is always contigous. + // is always contiguous. std::copy_n(output_state_ptr, n_batch * n_output, output_ptr); } @@ -2177,7 +2179,7 @@ TfLiteStatus EvalInteger8x8_16( const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, + const TfLiteLSTMParams* params, bool forward_sequence, bool time_major, const lstm_eval::IntegerLstmParameter* integer_lstm_param, TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output, TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2, @@ -2190,8 +2192,8 @@ TfLiteStatus EvalInteger8x8_16( max_time = 1; n_batch = input->dims->data[0]; } else { - max_time = input->dims->data[0]; - n_batch = input->dims->data[1]; + max_time = (time_major) ? input->dims->data[0] : input->dims->data[1]; + n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0]; } // n_cell and n_output will be the same size when there is no projection. @@ -2204,90 +2206,193 @@ TfLiteStatus EvalInteger8x8_16( // Get params for time/batch/sequence. const int output_batch_leading_dim = output->dims->data[output->dims->size - 1]; - const int input_step = n_batch * n_input; - const int output_step = n_batch * output_batch_leading_dim; - for (int t = 0; t < max_time; t++) { - const int t_rel = t; - int8_t* output_ptr = GetTensorData(output) + t_rel * output_step; - const int8_t* input_ptr = GetTensorData(input) + t_rel * input_step; - LstmStepInteger8x8_16( - input_ptr, GetTensorData(input_to_input_weights), - integer_lstm_param->effective_input_to_input_scale_a, - integer_lstm_param->effective_input_to_input_scale_b, - GetTensorData(input_to_forget_weights), - integer_lstm_param->effective_input_to_forget_scale_a, - integer_lstm_param->effective_input_to_forget_scale_b, - GetTensorData(input_to_cell_weights), - integer_lstm_param->effective_input_to_cell_scale_a, - integer_lstm_param->effective_input_to_cell_scale_b, - GetTensorData(input_to_output_weights), - integer_lstm_param->effective_input_to_output_scale_a, - integer_lstm_param->effective_input_to_output_scale_b, - GetTensorData(recurrent_to_input_weights), - integer_lstm_param->effective_recurrent_to_input_scale_a, - integer_lstm_param->effective_recurrent_to_input_scale_b, - GetTensorData(recurrent_to_forget_weights), - integer_lstm_param->effective_recurrent_to_forget_scale_a, - integer_lstm_param->effective_recurrent_to_forget_scale_b, - GetTensorData(recurrent_to_cell_weights), - integer_lstm_param->effective_recurrent_to_cell_scale_a, - integer_lstm_param->effective_recurrent_to_cell_scale_b, - GetTensorData(recurrent_to_output_weights), - integer_lstm_param->effective_recurrent_to_output_scale_a, - integer_lstm_param->effective_recurrent_to_output_scale_b, - GetTensorData(cell_to_input_weights), - integer_lstm_param->effective_cell_to_input_scale_a, - integer_lstm_param->effective_cell_to_input_scale_b, - GetTensorData(cell_to_forget_weights), - integer_lstm_param->effective_cell_to_forget_scale_a, - integer_lstm_param->effective_cell_to_forget_scale_b, - GetTensorData(cell_to_output_weights), - integer_lstm_param->effective_cell_to_output_scale_a, - integer_lstm_param->effective_cell_to_output_scale_b, - GetTensorData(projection_weights), - integer_lstm_param->effective_proj_scale_a, - integer_lstm_param->effective_proj_scale_b, - integer_lstm_param->hidden_zp, - integer_lstm_param->effective_hidden_scale_a, - integer_lstm_param->effective_hidden_scale_b, - GetTensorData(input_layer_norm_coefficients), - integer_lstm_param->layer_norm_input_scale_a, - integer_lstm_param->layer_norm_input_scale_b, - GetTensorData(forget_layer_norm_coefficients), - integer_lstm_param->layer_norm_forget_scale_a, - integer_lstm_param->layer_norm_forget_scale_b, - GetTensorData(cell_layer_norm_coefficients), - integer_lstm_param->layer_norm_cell_scale_a, - integer_lstm_param->layer_norm_cell_scale_b, - GetTensorData(output_layer_norm_coefficients), - integer_lstm_param->layer_norm_output_scale_a, - integer_lstm_param->layer_norm_output_scale_b, - GetTensorData(input_gate_bias), - GetTensorData(forget_gate_bias), - GetTensorData(cell_gate_bias), - GetTensorData(output_gate_bias), - integer_lstm_param->quantized_cell_clip, - integer_lstm_param->quantized_proj_clip, integer_lstm_param->cell_scale, - integer_lstm_param->input_variance_guard, - integer_lstm_param->forget_variance_guard, - integer_lstm_param->cell_variance_guard, - integer_lstm_param->output_variance_guard, - integer_lstm_param->input_to_forget_effective_bias.get(), - integer_lstm_param->recurrent_to_forget_effective_bias.get(), - integer_lstm_param->input_to_cell_effective_bias.get(), - integer_lstm_param->recurrent_to_cell_effective_bias.get(), - integer_lstm_param->input_to_output_effective_bias.get(), - integer_lstm_param->recurrent_to_output_effective_bias.get(), - integer_lstm_param->input_to_input_effective_bias.get(), - integer_lstm_param->recurrent_to_input_effective_bias.get(), - integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell, - n_input, n_output, GetTensorData(output_state), output_state_zp, - GetTensorData(cell_state), output_ptr, - GetTensorData(scratch0), GetTensorData(scratch1), - GetTensorData(scratch2), GetTensorData(scratch3), - GetTensorData(scratch4), GetTensorData(scratch5), - context); + if (time_major) { + const int input_step = n_batch * n_input; + const int output_step = n_batch * output_batch_leading_dim; + for (int t = 0; t < max_time; t++) { + const int t_rel = t; + int8_t* output_ptr = GetTensorData(output) + t_rel * output_step; + const int8_t* input_ptr = + GetTensorData(input) + t_rel * input_step; + LstmStepInteger8x8_16( + input_ptr, GetTensorData(input_to_input_weights), + integer_lstm_param->effective_input_to_input_scale_a, + integer_lstm_param->effective_input_to_input_scale_b, + GetTensorData(input_to_forget_weights), + integer_lstm_param->effective_input_to_forget_scale_a, + integer_lstm_param->effective_input_to_forget_scale_b, + GetTensorData(input_to_cell_weights), + integer_lstm_param->effective_input_to_cell_scale_a, + integer_lstm_param->effective_input_to_cell_scale_b, + GetTensorData(input_to_output_weights), + integer_lstm_param->effective_input_to_output_scale_a, + integer_lstm_param->effective_input_to_output_scale_b, + GetTensorData(recurrent_to_input_weights), + integer_lstm_param->effective_recurrent_to_input_scale_a, + integer_lstm_param->effective_recurrent_to_input_scale_b, + GetTensorData(recurrent_to_forget_weights), + integer_lstm_param->effective_recurrent_to_forget_scale_a, + integer_lstm_param->effective_recurrent_to_forget_scale_b, + GetTensorData(recurrent_to_cell_weights), + integer_lstm_param->effective_recurrent_to_cell_scale_a, + integer_lstm_param->effective_recurrent_to_cell_scale_b, + GetTensorData(recurrent_to_output_weights), + integer_lstm_param->effective_recurrent_to_output_scale_a, + integer_lstm_param->effective_recurrent_to_output_scale_b, + GetTensorData(cell_to_input_weights), + integer_lstm_param->effective_cell_to_input_scale_a, + integer_lstm_param->effective_cell_to_input_scale_b, + GetTensorData(cell_to_forget_weights), + integer_lstm_param->effective_cell_to_forget_scale_a, + integer_lstm_param->effective_cell_to_forget_scale_b, + GetTensorData(cell_to_output_weights), + integer_lstm_param->effective_cell_to_output_scale_a, + integer_lstm_param->effective_cell_to_output_scale_b, + GetTensorData(projection_weights), + integer_lstm_param->effective_proj_scale_a, + integer_lstm_param->effective_proj_scale_b, + integer_lstm_param->hidden_zp, + integer_lstm_param->effective_hidden_scale_a, + integer_lstm_param->effective_hidden_scale_b, + GetTensorData(input_layer_norm_coefficients), + integer_lstm_param->layer_norm_input_scale_a, + integer_lstm_param->layer_norm_input_scale_b, + GetTensorData(forget_layer_norm_coefficients), + integer_lstm_param->layer_norm_forget_scale_a, + integer_lstm_param->layer_norm_forget_scale_b, + GetTensorData(cell_layer_norm_coefficients), + integer_lstm_param->layer_norm_cell_scale_a, + integer_lstm_param->layer_norm_cell_scale_b, + GetTensorData(output_layer_norm_coefficients), + integer_lstm_param->layer_norm_output_scale_a, + integer_lstm_param->layer_norm_output_scale_b, + GetTensorData(input_gate_bias), + GetTensorData(forget_gate_bias), + GetTensorData(cell_gate_bias), + GetTensorData(output_gate_bias), + integer_lstm_param->quantized_cell_clip, + integer_lstm_param->quantized_proj_clip, + integer_lstm_param->cell_scale, + integer_lstm_param->input_variance_guard, + integer_lstm_param->forget_variance_guard, + integer_lstm_param->cell_variance_guard, + integer_lstm_param->output_variance_guard, + integer_lstm_param->input_to_forget_effective_bias.get(), + integer_lstm_param->recurrent_to_forget_effective_bias.get(), + integer_lstm_param->input_to_cell_effective_bias.get(), + integer_lstm_param->recurrent_to_cell_effective_bias.get(), + integer_lstm_param->input_to_output_effective_bias.get(), + integer_lstm_param->recurrent_to_output_effective_bias.get(), + integer_lstm_param->input_to_input_effective_bias.get(), + integer_lstm_param->recurrent_to_input_effective_bias.get(), + integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell, + n_input, n_output, GetTensorData(output_state), + output_state_zp, GetTensorData(cell_state), output_ptr, + GetTensorData(scratch0), GetTensorData(scratch1), + GetTensorData(scratch2), GetTensorData(scratch3), + GetTensorData(scratch4), GetTensorData(scratch5), + context); + } + } else { + for (int b = 0; b < n_batch; b++) { + const int input_step = n_input; + const int output_step = output_batch_leading_dim; + for (int t = 0; t < max_time; t++) { + // If this is the forward_sequence, step forward, otherwise step + // backwards. + const int t_rel = forward_sequence ? t : max_time - t - 1; + const int time_offset = b * max_time + t_rel; + const int8_t* input_ptr = + GetTensorData(input) + time_offset * input_step; + int8_t* output_ptr = + GetTensorData(output) + time_offset * output_step; + + // Offset the {output,cell}_state pointers to the right batch. + int8_t* output_state_ptr = + GetTensorData(output_state) + b * output_batch_leading_dim; + int16_t* cell_state_ptr = + GetTensorData(cell_state) + b * n_cell; + + LstmStepInteger8x8_16( + input_ptr, GetTensorData(input_to_input_weights), + integer_lstm_param->effective_input_to_input_scale_a, + integer_lstm_param->effective_input_to_input_scale_b, + GetTensorData(input_to_forget_weights), + integer_lstm_param->effective_input_to_forget_scale_a, + integer_lstm_param->effective_input_to_forget_scale_b, + GetTensorData(input_to_cell_weights), + integer_lstm_param->effective_input_to_cell_scale_a, + integer_lstm_param->effective_input_to_cell_scale_b, + GetTensorData(input_to_output_weights), + integer_lstm_param->effective_input_to_output_scale_a, + integer_lstm_param->effective_input_to_output_scale_b, + GetTensorData(recurrent_to_input_weights), + integer_lstm_param->effective_recurrent_to_input_scale_a, + integer_lstm_param->effective_recurrent_to_input_scale_b, + GetTensorData(recurrent_to_forget_weights), + integer_lstm_param->effective_recurrent_to_forget_scale_a, + integer_lstm_param->effective_recurrent_to_forget_scale_b, + GetTensorData(recurrent_to_cell_weights), + integer_lstm_param->effective_recurrent_to_cell_scale_a, + integer_lstm_param->effective_recurrent_to_cell_scale_b, + GetTensorData(recurrent_to_output_weights), + integer_lstm_param->effective_recurrent_to_output_scale_a, + integer_lstm_param->effective_recurrent_to_output_scale_b, + GetTensorData(cell_to_input_weights), + integer_lstm_param->effective_cell_to_input_scale_a, + integer_lstm_param->effective_cell_to_input_scale_b, + GetTensorData(cell_to_forget_weights), + integer_lstm_param->effective_cell_to_forget_scale_a, + integer_lstm_param->effective_cell_to_forget_scale_b, + GetTensorData(cell_to_output_weights), + integer_lstm_param->effective_cell_to_output_scale_a, + integer_lstm_param->effective_cell_to_output_scale_b, + GetTensorData(projection_weights), + integer_lstm_param->effective_proj_scale_a, + integer_lstm_param->effective_proj_scale_b, + integer_lstm_param->hidden_zp, + integer_lstm_param->effective_hidden_scale_a, + integer_lstm_param->effective_hidden_scale_b, + GetTensorData(input_layer_norm_coefficients), + integer_lstm_param->layer_norm_input_scale_a, + integer_lstm_param->layer_norm_input_scale_b, + GetTensorData(forget_layer_norm_coefficients), + integer_lstm_param->layer_norm_forget_scale_a, + integer_lstm_param->layer_norm_forget_scale_b, + GetTensorData(cell_layer_norm_coefficients), + integer_lstm_param->layer_norm_cell_scale_a, + integer_lstm_param->layer_norm_cell_scale_b, + GetTensorData(output_layer_norm_coefficients), + integer_lstm_param->layer_norm_output_scale_a, + integer_lstm_param->layer_norm_output_scale_b, + GetTensorData(input_gate_bias), + GetTensorData(forget_gate_bias), + GetTensorData(cell_gate_bias), + GetTensorData(output_gate_bias), + integer_lstm_param->quantized_cell_clip, + integer_lstm_param->quantized_proj_clip, + integer_lstm_param->cell_scale, + integer_lstm_param->input_variance_guard, + integer_lstm_param->forget_variance_guard, + integer_lstm_param->cell_variance_guard, + integer_lstm_param->output_variance_guard, + integer_lstm_param->input_to_forget_effective_bias.get(), + integer_lstm_param->recurrent_to_forget_effective_bias.get(), + integer_lstm_param->input_to_cell_effective_bias.get(), + integer_lstm_param->recurrent_to_cell_effective_bias.get(), + integer_lstm_param->input_to_output_effective_bias.get(), + integer_lstm_param->recurrent_to_output_effective_bias.get(), + integer_lstm_param->input_to_input_effective_bias.get(), + integer_lstm_param->recurrent_to_input_effective_bias.get(), + integer_lstm_param->projection_effective_bias.get(), /*n_batch=*/1, + n_cell, n_input, n_output, output_state_ptr, output_state_zp, + cell_state_ptr, output_ptr, GetTensorData(scratch0), + GetTensorData(scratch1), GetTensorData(scratch2), + GetTensorData(scratch3), GetTensorData(scratch4), + GetTensorData(scratch5), context); + } + } } return kTfLiteOk; diff --git a/tensorflow/lite/kernels/lstm_eval.h b/tensorflow/lite/kernels/lstm_eval.h index 5807c9ee56d..6e286626fb9 100644 --- a/tensorflow/lite/kernels/lstm_eval.h +++ b/tensorflow/lite/kernels/lstm_eval.h @@ -188,7 +188,7 @@ TfLiteStatus EvalInteger8x8_16( const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, + const TfLiteLSTMParams* params, bool forward_sequence, bool time_major, const lstm_eval::IntegerLstmParameter* integer_lstm_param, TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output, TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2, diff --git a/tensorflow/lite/kernels/lstm_eval_test.cc b/tensorflow/lite/kernels/lstm_eval_test.cc index c7d935a4b4f..bdad6c790eb 100644 --- a/tensorflow/lite/kernels/lstm_eval_test.cc +++ b/tensorflow/lite/kernels/lstm_eval_test.cc @@ -617,8 +617,9 @@ void TestOneFullyQuantizedLSTM() { one_parameter.GetOutputLayerNorm(), one_parameter.GetInputBias(), one_parameter.GetForgetBias(), one_parameter.GetCellBias(), one_parameter.GetOutputBias(), one_parameter.GetProjection(), - one_parameter.GetProjectionBias(), nullptr, param, activation, cell, - output, one_parameter.GetScratch0(), one_parameter.GetScratch1(), + one_parameter.GetProjectionBias(), nullptr, /*forward_sequence=*/true, + /*time_major=*/true, param, activation, cell, output, + one_parameter.GetScratch0(), one_parameter.GetScratch1(), one_parameter.GetScratch2(), one_parameter.GetScratch3(), one_parameter.GetScratch4(), one_parameter.GetScratch5(), &context); diff --git a/tensorflow/lite/kernels/quantize.cc b/tensorflow/lite/kernels/quantize.cc index 8f396355777..8ddc18be2b1 100644 --- a/tensorflow/lite/kernels/quantize.cc +++ b/tensorflow/lite/kernels/quantize.cc @@ -120,8 +120,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } else { // Requantize use case. if (input->type == kTfLiteInt16) { - TF_LITE_ENSURE( - context, output->type == kTfLiteInt8 || output->type == kTfLiteInt16); + TF_LITE_ENSURE(context, output->type == kTfLiteInt8 || + output->type == kTfLiteInt16 || + output->type == kTfLiteInt32); } else { TF_LITE_ENSURE(context, input->type == kTfLiteInt8 || input->type == kTfLiteUInt8); @@ -198,6 +199,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output->params.zero_point, GetTensorData(output)); return kTfLiteOk; + case kTfLiteInt32: + // This case is not supported by the converter or other TFLite tools. + // The only use case is for applications that take quantized int32 + // inference outputs. + Requantize(GetTensorData(input), + MatchingFlatSize(input_shape, output_shape), + data->output_multiplier, data->output_shift, + input->params.zero_point, + output->params.zero_point, + GetTensorData(output)); + return kTfLiteOk; default: ReportError(context, input->type, output->type); return kTfLiteError; diff --git a/tensorflow/lite/kernels/quantize_test.cc b/tensorflow/lite/kernels/quantize_test.cc index d7392b3e3ea..a8d68f6875b 100644 --- a/tensorflow/lite/kernels/quantize_test.cc +++ b/tensorflow/lite/kernels/quantize_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include #include @@ -458,5 +459,59 @@ TEST(QuantizeOpTest, Int16Int8SmallerScaleNeonPath) { 19, 17, 15, 13, 11, 9, 7, 5, 3, 1})); } +// Input scale 1.0, output scale 1.0, input zeropoint 0, output zeropoint 0 +TEST(QuantizeOpTest, Int16Int32SameScale) { + QuantizeOpModel m({TensorType_INT16, + {1, 1, 2, 5}, + std::numeric_limits::min(), + std::numeric_limits::max()}, + {TensorType_INT32, + {1, 1, 2, 5}, + std::numeric_limits::min(), + std::numeric_limits::max()}); + + // Input will quantized to {1,3,5,7,9,11,13,15,17,19}. + m.SetInputAndQuantize({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10})); +} + +// Input scale 0.500000, output scale 1.000000, input zeropoint -1, output +// zeropoint 0 +TEST(QuantizeOpTest, Int16Int32LargerScale) { + QuantizeOpModel m({TensorType_INT16, + {1, 1, 2, 5}, + std::numeric_limits::min() / 2.0, + std::numeric_limits::max() / 2.0}, + {TensorType_INT32, + {1, 1, 2, 5}, + std::numeric_limits::min(), + std::numeric_limits::max()}); + + m.SetInputAndQuantize({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10})); +} + +// Input scale 1.000000, output scale 0.500000, input zeropoint -1, output +// zeropoint 0 +TEST(QuantizeOpTest, Int16Int32SmallerScale) { + QuantizeOpModel m({TensorType_INT16, + {1, 1, 2, 5}, + std::numeric_limits::min(), + std::numeric_limits::max()}, + {TensorType_INT32, + {1, 1, 2, 5}, + std::numeric_limits::min() / 2.0, + std::numeric_limits::max() / 2.0}); + + m.SetInputAndQuantize({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({2, 4, 6, 8, 10, 12, 14, 16, 18, 20})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 8d8095bdc82..d2bb6dfd632 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -294,6 +294,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(), /* min_version = */ 1, /* max_version = */ 3); + AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM()); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc index ad513e9f918..64274812d7f 100644 --- a/tensorflow/lite/kernels/test_util.cc +++ b/tensorflow/lite/kernels/test_util.cc @@ -193,7 +193,10 @@ void SingleOpModel::BuildInterpreter(std::vector> input_shapes, UpdateOpVersion(buffer_pointer); if (!resolver_) { - auto resolver = new ops::builtin::BuiltinOpResolver(); + MutableOpResolver* resolver = + apply_delegate + ? new ops::builtin::BuiltinOpResolver() + : new ops::builtin::BuiltinOpResolverWithoutDefaultDelegates(); for (const auto& reg : custom_registrations_) { resolver->AddCustom(reg.first.data(), reg.second()); } diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index f739827c5b3..9cd272f3030 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -485,6 +485,10 @@ class SingleOpModel { // Build the interpreter for this model. Also, resize and allocate all // tensors given the shapes of the inputs. + // Note: 'apply_delegate' also serves to tell whether default TfLite delegates + // should be applied implicitly for a test case. For example, when testing the + // specific implementation of a TfLite delegate, it might be necessary to set + // this to false. void BuildInterpreter(std::vector> input_shapes, int num_threads, bool allow_fp32_relax_to_fp16, bool apply_delegate, bool allocate_and_delegate = true); diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index 240d7125a5f..d1a7c6ba9b2 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/lstm_eval.h" @@ -31,15 +33,350 @@ namespace tflite { namespace ops { namespace builtin { namespace unidirectional_sequence_lstm { +namespace { struct OpData { // If the lstm is layer norm. - bool is_layer_norm_lstm; + bool use_layer_norm; // The scratch tensor index. int scratch_tensor_index; bool compute_row_sums = false; + + lstm_eval::IntegerLstmParameter integer_lstm_param; }; +TfLiteStatus PopulateQuantizedLstmParams8x8_16( + TfLiteContext* context, TfLiteNode* node, + lstm_eval::IntegerLstmParameter* integer_lstm_param) { + // Calculate quantized clip for projection and cell. + const auto* params = + static_cast(node->builtin_data); + const float cell_clip = params->cell_clip; + const float proj_clip = params->proj_clip; + + const TfLiteTensor* cell_state = + GetVariableInput(context, node, lstm::full::kCellStateTensor); + TF_LITE_ENSURE(context, cell_state != nullptr); + TfLiteTensor* output_tensor; + TF_LITE_ENSURE_OK( + context, + GetOutputSafe(context, node, lstm::full::kOutputTensor, &output_tensor)); + + auto* cell_state_params = + static_cast(cell_state->quantization.params); + auto* proj_params = static_cast( + output_tensor->quantization.params); + if (cell_clip > 0.0) { + integer_lstm_param->quantized_cell_clip = static_cast(std::min( + std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f), + 32767.0f)); + } else { + integer_lstm_param->quantized_cell_clip = 0; + } + if (proj_clip > 0.0) { + integer_lstm_param->quantized_proj_clip = static_cast(std::min( + std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f)); + } else { + integer_lstm_param->quantized_proj_clip = 0; + } + + // Calculate effective scales. + OpData* op_data = static_cast(node->user_data); + const bool use_layer_norm = op_data->use_layer_norm; + + const TfLiteTensor* input; + TF_LITE_ENSURE_OK( + context, GetInputSafe(context, node, lstm::full::kInputTensor, &input)); + + const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor( + context, node, lstm::full::kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor, + &input_to_forget_weights)); + const TfLiteTensor* input_to_cell_weights; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, + lstm::full::kInputToCellWeightsTensor, + &input_to_cell_weights)); + const TfLiteTensor* input_to_output_weights; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor, + &input_to_output_weights)); + + const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor( + context, node, lstm::full::kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor, + &recurrent_to_forget_weights)); + const TfLiteTensor* recurrent_to_cell_weights; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor, + &recurrent_to_cell_weights)); + const TfLiteTensor* recurrent_to_output_weights; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor, + &recurrent_to_output_weights)); + + const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor( + context, node, lstm::full::kCellToInputWeightsTensor); + const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor( + context, node, lstm::full::kCellToForgetWeightsTensor); + const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor( + context, node, lstm::full::kCellToOutputWeightsTensor); + + const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( + context, node, lstm::full::kInputLayerNormCoefficientsTensor); + const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor( + context, node, lstm::full::kForgetLayerNormCoefficientsTensor); + const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor( + context, node, lstm::full::kCellLayerNormCoefficientsTensor); + const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor( + context, node, lstm::full::kOutputLayerNormCoefficientsTensor); + + const TfLiteTensor* projection_weights = GetOptionalInputTensor( + context, node, lstm::full::kProjectionWeightsTensor); + + TfLiteTensor* output_state = + GetVariableInput(context, node, lstm::full::kOutputStateTensor); + TF_LITE_ENSURE(context, output_state != nullptr); + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + const bool use_projection = (projection_weights != nullptr); + + // Get intermediate scales and zero points. + std::vector intermediate_scale; + std::vector intermediate_zp; + for (int i = 0; i < 4; ++i) { + if (use_layer_norm) { + TfLiteTensor* intermediate; + TF_LITE_ENSURE_OK(context, + GetIntermediatesSafe(context, node, i, &intermediate)); + auto* params = static_cast( + intermediate->quantization.params); + intermediate_scale.push_back(params->scale->data[0]); + intermediate_zp.push_back(params->zero_point->data[0]); + } else { + // Q3.12 for activation functions. + intermediate_scale.push_back(std::pow(2, -12)); + intermediate_zp.push_back(0); + } + } + // In the absense of projection, hidden becomes otuput and this intermediate + // is ignored. + TfLiteTensor* hidden; + TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden)); + auto* hidden_params = + static_cast(hidden->quantization.params); + intermediate_scale.push_back(hidden_params->scale->data[0]); + intermediate_zp.push_back(hidden_params->zero_point->data[0]); + + // Scales. + const float default_scale = 1.0; + float input_scale = default_scale; + float input_to_input_weight_scale = default_scale; + float recurrent_to_input_weight_scale = default_scale; + float cell_to_input_weight_scale = default_scale; + float input_to_forget_weight_scale = default_scale; + float recurrent_to_forget_weight_scale = default_scale; + float cell_to_forget_weight_scale = default_scale; + float input_to_cell_weight_scale = default_scale; + float recurrent_to_cell_weight_scale = default_scale; + float input_to_output_weight_scale = default_scale; + float recurrent_to_output_weight_scale = default_scale; + float cell_to_output_weight_scale = default_scale; + float projection_weight_scale = default_scale; + float layer_norm_input_scale = default_scale; + float layer_norm_forget_scale = default_scale; + float layer_norm_cell_scale = default_scale; + float layer_norm_output_scale = default_scale; + float output_state_scale = default_scale; + int cell_scale = 1; + + // Effective scales. + float effective_input_to_input_scale = default_scale; + float effective_recurrent_to_input_scale = default_scale; + float effective_cell_to_input_scale = default_scale; + float effective_input_to_forget_scale = default_scale; + float effective_recurrent_to_forget_scale = default_scale; + float effective_cell_to_forget_scale = default_scale; + float effective_input_to_cell_scale = default_scale; + float effective_recurrent_to_cell_scale = default_scale; + float effective_input_to_output_scale = default_scale; + float effective_recurrent_to_output_scale = default_scale; + float effective_cell_to_output_scale = default_scale; + float effective_proj_scale = default_scale; + float effective_hidden_scale = default_scale; + + // Populate scales. + if (!use_cifg) { + input_to_input_weight_scale = input_to_input_weights->params.scale; + recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale; + } + + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weight_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weight_scale = cell_to_forget_weights->params.scale; + cell_to_output_weight_scale = cell_to_output_weights->params.scale; + } + + if (use_layer_norm) { + if (!use_cifg) { + layer_norm_input_scale = input_layer_norm_coefficients->params.scale; + } + layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale; + layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale; + layer_norm_output_scale = output_layer_norm_coefficients->params.scale; + } + + if (use_projection) { + projection_weight_scale = projection_weights->params.scale; + } + output_state_scale = output_state->params.scale; + + input_to_forget_weight_scale = input_to_forget_weights->params.scale; + input_to_cell_weight_scale = input_to_cell_weights->params.scale; + input_to_output_weight_scale = input_to_output_weights->params.scale; + recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale; + recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale; + recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale; + + // Check cell state (already used above) + TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale)); + // TF_LITE_ENSURE(context, cell_scale <= -9); + integer_lstm_param->cell_scale = cell_scale; + input_scale = input->params.scale; + + // Calculate effective scales. + if (!use_cifg) { + effective_input_to_input_scale = + input_to_input_weight_scale * input_scale / intermediate_scale[0]; + effective_recurrent_to_input_scale = recurrent_to_input_weight_scale * + output_state_scale / + intermediate_scale[0]; + } + effective_input_to_forget_scale = + input_to_forget_weight_scale * input_scale / intermediate_scale[1]; + effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale * + output_state_scale / + intermediate_scale[1]; + + effective_input_to_cell_scale = + input_to_cell_weight_scale * input_scale / intermediate_scale[2]; + effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale * + output_state_scale / + intermediate_scale[2]; + + effective_input_to_output_scale = + input_to_output_weight_scale * input_scale / intermediate_scale[3]; + effective_recurrent_to_output_scale = recurrent_to_output_weight_scale * + output_state_scale / + intermediate_scale[3]; + + effective_hidden_scale = + std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15); + + effective_proj_scale = + projection_weight_scale * intermediate_scale[4] / output_state_scale; + + if (use_peephole) { + if (!use_cifg) { + effective_cell_to_input_scale = std::pow(2, cell_scale) * // NOLINT + cell_to_input_weight_scale / + intermediate_scale[0]; + } + effective_cell_to_forget_scale = std::pow(2, cell_scale) * // NOLINT + cell_to_forget_weight_scale / + intermediate_scale[1]; + effective_cell_to_output_scale = std::pow(2, cell_scale) * // NOLINT + cell_to_output_weight_scale / + intermediate_scale[3]; + } + + // Decompose scales. + QuantizeMultiplier(effective_input_to_input_scale, + &integer_lstm_param->effective_input_to_input_scale_a, + &integer_lstm_param->effective_input_to_input_scale_b); + QuantizeMultiplier(effective_recurrent_to_input_scale, + &integer_lstm_param->effective_recurrent_to_input_scale_a, + &integer_lstm_param->effective_recurrent_to_input_scale_b); + QuantizeMultiplier(effective_cell_to_input_scale, + &integer_lstm_param->effective_cell_to_input_scale_a, + &integer_lstm_param->effective_cell_to_input_scale_b); + QuantizeMultiplier(effective_input_to_forget_scale, + &integer_lstm_param->effective_input_to_forget_scale_a, + &integer_lstm_param->effective_input_to_forget_scale_b); + QuantizeMultiplier( + effective_recurrent_to_forget_scale, + &integer_lstm_param->effective_recurrent_to_forget_scale_a, + &integer_lstm_param->effective_recurrent_to_forget_scale_b); + QuantizeMultiplier(effective_cell_to_forget_scale, + &integer_lstm_param->effective_cell_to_forget_scale_a, + &integer_lstm_param->effective_cell_to_forget_scale_b); + QuantizeMultiplier(effective_input_to_cell_scale, + &integer_lstm_param->effective_input_to_cell_scale_a, + &integer_lstm_param->effective_input_to_cell_scale_b); + QuantizeMultiplier(effective_recurrent_to_cell_scale, + &integer_lstm_param->effective_recurrent_to_cell_scale_a, + &integer_lstm_param->effective_recurrent_to_cell_scale_b); + QuantizeMultiplier(effective_input_to_output_scale, + &integer_lstm_param->effective_input_to_output_scale_a, + &integer_lstm_param->effective_input_to_output_scale_b); + QuantizeMultiplier( + effective_recurrent_to_output_scale, + &integer_lstm_param->effective_recurrent_to_output_scale_a, + &integer_lstm_param->effective_recurrent_to_output_scale_b); + QuantizeMultiplier(effective_cell_to_output_scale, + &integer_lstm_param->effective_cell_to_output_scale_a, + &integer_lstm_param->effective_cell_to_output_scale_b); + QuantizeMultiplier(effective_proj_scale, + &integer_lstm_param->effective_proj_scale_a, + &integer_lstm_param->effective_proj_scale_b); + QuantizeMultiplier(effective_hidden_scale, + &integer_lstm_param->effective_hidden_scale_a, + &integer_lstm_param->effective_hidden_scale_b); + QuantizeMultiplier(layer_norm_input_scale, + &integer_lstm_param->layer_norm_input_scale_a, + &integer_lstm_param->layer_norm_input_scale_b); + QuantizeMultiplier(layer_norm_forget_scale, + &integer_lstm_param->layer_norm_forget_scale_a, + &integer_lstm_param->layer_norm_forget_scale_b); + QuantizeMultiplier(layer_norm_cell_scale, + &integer_lstm_param->layer_norm_cell_scale_a, + &integer_lstm_param->layer_norm_cell_scale_b); + QuantizeMultiplier(layer_norm_output_scale, + &integer_lstm_param->layer_norm_output_scale_a, + &integer_lstm_param->layer_norm_output_scale_b); + + integer_lstm_param->hidden_zp = intermediate_zp[4]; + + // 10000 is used to make sure the kernel logic does not overflow. + if (!use_cifg) { + integer_lstm_param->input_variance_guard = + std::max(1, static_cast(10000 * layer_norm_input_scale)); + } + integer_lstm_param->forget_variance_guard = + std::max(1, static_cast(10000 * layer_norm_forget_scale)); + integer_lstm_param->cell_variance_guard = + std::max(1, static_cast(10000 * layer_norm_cell_scale)); + integer_lstm_param->output_variance_guard = + std::max(1, static_cast(10000 * layer_norm_output_scale)); + + return kTfLiteOk; +} + +} // namespace + // Temporary tensors enum TemporaryTensor { kScratchBuffer = 0, @@ -72,7 +409,7 @@ void Free(TfLiteContext* context, void* buffer) { TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell, - bool is_layer_norm_lstm) { + bool use_layer_norm, bool is_integer) { const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. @@ -151,6 +488,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, if (cell_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_TYPES_EQ( + context, cell_to_input_weights->type, + is_integer ? kTfLiteInt16 : input_to_forget_weights->type); } const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor( @@ -158,6 +498,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, if (cell_to_forget_weights != nullptr) { TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_TYPES_EQ( + context, cell_to_forget_weights->type, + is_integer ? kTfLiteInt16 : input_to_forget_weights->type); } const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor( @@ -165,6 +508,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, if (cell_to_output_weights != nullptr) { TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_TYPES_EQ( + context, cell_to_output_weights->type, + is_integer ? kTfLiteInt16 : input_to_forget_weights->type); } // Making sure the peephole weights are there all or none. @@ -186,6 +532,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, } else { TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32); + } } const TfLiteTensor* forget_gate_bias; @@ -194,6 +545,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, &forget_gate_bias)); TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32); + } const TfLiteTensor* cell_gate_bias; TF_LITE_ENSURE_OK(context, @@ -201,6 +557,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, &cell_gate_bias)); TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32); + } const TfLiteTensor* output_gate_bias; TF_LITE_ENSURE_OK( @@ -208,6 +569,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, &output_gate_bias)); TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32); + } const TfLiteTensor* projection_weights = GetOptionalInputTensor( context, node, lstm::full::kProjectionWeightsTensor); @@ -222,6 +588,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, if (projection_bias != nullptr) { TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32); + } } // Making sure the projection tensors are consistent: @@ -233,7 +604,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, ((projection_weights != nullptr) || (projection_bias == nullptr)); TF_LITE_ENSURE(context, projecton_tensors_consistent == true); - if (is_layer_norm_lstm) { + if (use_layer_norm) { const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( context, node, lstm::full::kInputLayerNormCoefficientsTensor); if (use_cifg) { @@ -243,8 +614,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1); TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0], n_cell); - TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type, - kTfLiteFloat32); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type, + kTfLiteInt16); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type, + kTfLiteFloat32); + } } const TfLiteTensor* forget_layer_norm_coefficients; @@ -255,8 +631,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1); TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0], n_cell); - TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type, - kTfLiteFloat32); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type, + kTfLiteInt16); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type, + kTfLiteFloat32); + } const TfLiteTensor* cell_layer_norm_coefficients; TF_LITE_ENSURE_OK(context, @@ -266,8 +647,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1); TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0], n_cell); - TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type, - kTfLiteFloat32); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type, + kTfLiteInt16); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type, + kTfLiteFloat32); + } const TfLiteTensor* output_layer_norm_coefficients; TF_LITE_ENSURE_OK( @@ -277,13 +663,185 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1); TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0], n_cell); - TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type, - kTfLiteFloat32); + if (is_integer) { + TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type, + kTfLiteInt16); + } else { + TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type, + kTfLiteFloat32); + } } return kTfLiteOk; } +TfLiteStatus PrecomputeZeroPointTimesWeightWithBias( + TfLiteContext* context, int32_t zero_point, + const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor, + std::unique_ptr* output) { + if (weight_tensor == nullptr) { + return kTfLiteOk; + } + + const RuntimeShape& weight_shape = GetTensorShape(weight_tensor); + TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2); + const int row = weight_shape.Dims(0); + const int col = weight_shape.Dims(1); + output->reset(new int32_t[row]); + if (bias_tensor == nullptr) { + memset(output->get(), 0, row * sizeof(int32_t)); + } else { + const int32_t* bias = GetTensorData(bias_tensor); + memcpy(output->get(), bias, row * sizeof(int32_t)); + } + if (zero_point != 0) { + const int8_t* weight = GetTensorData(weight_tensor); + tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, col, + output->get()); + } + return kTfLiteOk; +} + +TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context, + OpData* op_data, + TfLiteNode* node) { + const TfLiteTensor* input; + TF_LITE_ENSURE_OK( + context, GetInputSafe(context, node, lstm::full::kInputTensor, &input)); + const TfLiteTensor* output_state = + GetVariableInput(context, node, lstm::full::kOutputStateTensor); + TF_LITE_ENSURE(context, output_state != nullptr); + + const int32_t input_zero_point = -input->params.zero_point; + const int32_t output_state_zero_point = -output_state->params.zero_point; + + const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor( + context, node, lstm::full::kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor, + &input_to_forget_weights)); + const TfLiteTensor* input_to_cell_weights; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, + lstm::full::kInputToCellWeightsTensor, + &input_to_cell_weights)); + const TfLiteTensor* input_to_output_weights; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor, + &input_to_output_weights)); + + const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor( + context, node, lstm::full::kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor, + &recurrent_to_forget_weights)); + const TfLiteTensor* recurrent_to_cell_weights; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor, + &recurrent_to_cell_weights)); + const TfLiteTensor* recurrent_to_output_weights; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor, + &recurrent_to_output_weights)); + + const TfLiteTensor* projection_weights = GetOptionalInputTensor( + context, node, lstm::full::kProjectionWeightsTensor); + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor); + + lstm_eval::IntegerLstmParameter* integer_lstm_params = + &op_data->integer_lstm_param; + + const TfLiteTensor* intermediate = + &context->tensors[node->intermediates->data[4]]; + const auto* params = + static_cast(intermediate->quantization.params); + const int32_t hidden_zp = params->zero_point->data[0]; + + // Get bias and perform zero point calculation. + // When there is layer normalization, the gate bias does not apply to matmul + // directly: + // y = ln(w * x + w * r + w * c) + b. + const bool is_layer_norm = op_data->use_layer_norm; + + // Forget gate. + const TfLiteTensor* forget_gate_bias = + is_layer_norm + ? nullptr + : GetInput(context, node, lstm::full::kForgetGateBiasTensor); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, input_zero_point, input_to_forget_weights, forget_gate_bias, + &(integer_lstm_params->input_to_forget_effective_bias))); + + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, output_state_zero_point, recurrent_to_forget_weights, + nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias))); + + // Modulation gate. + const TfLiteTensor* cell_gate_bias = + is_layer_norm ? nullptr + : GetInput(context, node, lstm::full::kCellGateBiasTensor); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, input_zero_point, input_to_cell_weights, cell_gate_bias, + &(integer_lstm_params->input_to_cell_effective_bias))); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, output_state_zero_point, recurrent_to_cell_weights, nullptr, + &(integer_lstm_params->recurrent_to_cell_effective_bias))); + + // Output gate. + const TfLiteTensor* output_gate_bias = + is_layer_norm + ? nullptr + : GetInput(context, node, lstm::full::kOutputGateBiasTensor); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, input_zero_point, input_to_output_weights, output_gate_bias, + &(integer_lstm_params->input_to_output_effective_bias))); + + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, output_state_zero_point, recurrent_to_output_weights, + nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias))); + + // Input gate. The calculation is only meaningful for non-cifg case. + const TfLiteTensor* input_gate_bias = + is_layer_norm ? nullptr + : GetInput(context, node, lstm::full::kInputGateBiasTensor); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, input_zero_point, input_to_input_weights, input_gate_bias, + &(integer_lstm_params->input_to_input_effective_bias))); + TF_LITE_ENSURE_OK( + context, + PrecomputeZeroPointTimesWeightWithBias( + context, output_state_zero_point, recurrent_to_input_weights, nullptr, + &(integer_lstm_params->recurrent_to_input_effective_bias))); + + // Projection bias. The calculation is only meaningful for with projection. + TF_LITE_ENSURE_OK(context, + PrecomputeZeroPointTimesWeightWithBias( + context, hidden_zp, projection_weights, projection_bias, + &(integer_lstm_params->projection_effective_bias))); + return kTfLiteOk; +} + // Resize the output and state tensors based on the sizes of the input tensors. // Allocate a temporary scratch tensor. Also check that the sizes of the input // tensors match each other. @@ -292,18 +850,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int scratch_tensor_index = op_data->scratch_tensor_index; // Check we have all the inputs and outputs we need. - bool is_layer_norm_lstm = false; + bool use_layer_norm = false; if (node->inputs->size == 24) { const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor( context, node, lstm::full::kForgetLayerNormCoefficientsTensor); if (forget_layer_norm_coefficients == nullptr) { - is_layer_norm_lstm = false; + use_layer_norm = false; } else { - is_layer_norm_lstm = true; + use_layer_norm = true; } } else if (node->inputs->size == 20) { // This is deprecated and is only kept here for backward compatibility. - is_layer_norm_lstm = false; + use_layer_norm = false; } else { context->ReportError( context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs", @@ -311,14 +869,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - op_data->is_layer_norm_lstm = is_layer_norm_lstm; + op_data->use_layer_norm = use_layer_norm; // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. const TfLiteTensor* input; TF_LITE_ENSURE_OK( context, GetInputSafe(context, node, lstm::full::kInputTensor, &input)); - TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32); + const bool is_integer = input->type == kTfLiteInt8; TF_LITE_ENSURE(context, input->dims->size > 1); const auto* params = reinterpret_cast( @@ -347,9 +905,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - TF_LITE_ENSURE_OK(context, - CheckInputTensorDimensions(context, node, n_input, n_output, - n_cell, is_layer_norm_lstm)); + TF_LITE_ENSURE_OK( + context, CheckInputTensorDimensions(context, node, n_input, n_output, + n_cell, use_layer_norm, is_integer)); // Get the pointer to output, output_state and cell_state buffer tensors. TfLiteTensor* output; @@ -375,9 +933,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size)); + if (is_integer) { + const int num_intermediate_tensors = node->intermediates->size; + TF_LITE_ENSURE(context, num_intermediate_tensors == 5); + } + TfLiteIntArrayFree(node->temporaries); if (IsHybridOp(input, input_to_output_weights)) { node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); + } else if (is_integer) { + node->temporaries = TfLiteIntArrayCreate(6); } else { node->temporaries = TfLiteIntArrayCreate(1); } @@ -590,6 +1155,50 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context, context->ResizeTensor(context, row_sums, row_sums_size)); } } + + if (is_integer) { + // Integer UnidirectionalSequenceLSTM prepare function for 8x8->16. + // This code path needs 5 intermediate tensors per Op. + // Populate quantization parameters. + PopulateQuantizedLstmParams8x8_16(context, node, + &op_data->integer_lstm_param); + // Allocate scratch buffer. Need 6 16bit buffer with size n_batch * n_cell + // and 1 8bit buffer with size n_batch * n_cell. We also need 1 32 bit + // buffer with size n_batch * n_cell. + // + // Handle cifg case as well, which might save one buffer. + for (int scratch_index = 0; scratch_index < 6; ++scratch_index) { + node->temporaries->data[scratch_index] = + op_data->scratch_tensor_index + scratch_index; + TfLiteTensor* scratch_tensor; + TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, scratch_index, + &scratch_tensor)); + + scratch_tensor->type = kTfLiteInt16; + if (scratch_index == 4) { + scratch_tensor->type = kTfLiteInt8; + } else if (scratch_index == 5) { + scratch_tensor->type = kTfLiteInt32; + } + + scratch_tensor->allocation_type = kTfLiteArenaRw; + const int scratch_dimension[2] = {n_batch, n_cell}; + if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2, + scratch_dimension)) { + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + scratch_buffer_size->data[1] = n_cell; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, scratch_tensor, + scratch_buffer_size)); + } + } + + // Populate precomputed zp * weight. + TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias( + context, op_data, node)); + } + return kTfLiteOk; } @@ -598,7 +1207,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { reinterpret_cast( node->builtin_data); const OpData* op_data = reinterpret_cast(node->user_data); - const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm; + const bool use_layer_norm = op_data->use_layer_norm; const bool time_major = params->time_major; const TfLiteTensor* input; TF_LITE_ENSURE_OK( @@ -666,11 +1275,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* projection_bias = GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor); - // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer; - TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer, - &scratch_buffer)); - TfLiteTensor* output_state = GetVariableInput(context, node, lstm::full::kOutputStateTensor); TFLITE_DCHECK(output_state != nullptr); @@ -679,25 +1283,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(cell_state != nullptr); const TfLiteTensor* input_layer_norm_coefficients = - is_layer_norm_lstm + use_layer_norm ? GetOptionalInputTensor( context, node, lstm::full::kInputLayerNormCoefficientsTensor) : nullptr; const TfLiteTensor* forget_layer_norm_coefficients = - is_layer_norm_lstm - ? GetInput(context, node, - lstm::full::kForgetLayerNormCoefficientsTensor) - : nullptr; + use_layer_norm ? GetInput(context, node, + lstm::full::kForgetLayerNormCoefficientsTensor) + : nullptr; const TfLiteTensor* cell_layer_norm_coefficients = - is_layer_norm_lstm - ? GetInput(context, node, - lstm::full::kCellLayerNormCoefficientsTensor) - : nullptr; + use_layer_norm ? GetInput(context, node, + lstm::full::kCellLayerNormCoefficientsTensor) + : nullptr; const TfLiteTensor* output_layer_norm_coefficients = - is_layer_norm_lstm - ? GetInput(context, node, - lstm::full::kOutputLayerNormCoefficientsTensor) - : nullptr; + use_layer_norm ? GetInput(context, node, + lstm::full::kOutputLayerNormCoefficientsTensor) + : nullptr; TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, @@ -712,6 +1313,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (input_to_output_weights->type) { case kTfLiteFloat32: { + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer; + TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer, + &scratch_buffer)); return lstm_eval::EvalFloat( input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, @@ -733,53 +1338,96 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } case kTfLiteUInt8: case kTfLiteInt8: { - OpData* op_data = reinterpret_cast(node->user_data); - TfLiteTensor* row_sums; - TF_LITE_ENSURE_OK(context, - GetTemporarySafe(context, node, kRowSums, &row_sums)); - const int row_sums_size = row_sums->dims->data[0]; - return lstm_eval::EvalHybrid( - input, input_to_input_weights, - /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights, - /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights, - /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights, - /*input_to_output_weights_ledger*/ nullptr, - recurrent_to_input_weights, - /*recurrent_to_input_weights_ledger*/ nullptr, - recurrent_to_forget_weights, - /*recurrent_to_forget_weights_ledger*/ nullptr, - recurrent_to_cell_weights, - /*recurrent_to_cell_weights_ledger*/ nullptr, - recurrent_to_output_weights, - /*recurrent_to_output_weights_ledger*/ nullptr, cell_to_input_weights, - cell_to_forget_weights, cell_to_output_weights, - input_layer_norm_coefficients, forget_layer_norm_coefficients, - cell_layer_norm_coefficients, output_layer_norm_coefficients, - /*aux_input=*/nullptr, - /*aux_input_to_input_weights=*/nullptr, - /*aux_input_to_forget_weights=*/nullptr, - /*aux_input_to_cell_weights=*/nullptr, - /*aux_input_to_output_weights=*/nullptr, input_gate_bias, - forget_gate_bias, cell_gate_bias, output_gate_bias, - projection_weights, /*projection_weights_ledger*/ nullptr, - projection_bias, &lstm_params, - /*forward_sequence=*/true, time_major, - /*output_offset=*/0, scratch_buffer, - GetTemporary(context, node, kInputScalingFactors), - /*aux_input_sf=*/nullptr, - GetTemporary(context, node, kOutputStateScalingFactors), - GetTemporary(context, node, kProductScalingFactors), - GetTemporary(context, node, kRecoveredCellWeights), - GetTemporary(context, node, kInputQuantized), - /*aux_input_quantized=*/nullptr, - GetTemporary(context, node, kOutputStateQuantized), - GetTemporary(context, node, kCellStateQuantized), output_state, - cell_state, GetTemporary(context, node, kAccumScratch), output, - GetTemporary(context, node, kInputZeroPoints), - /*aux_input_zp=*/nullptr, - GetTemporary(context, node, kOutputStateZeroPoints), row_sums, - row_sums_size, &op_data->compute_row_sums, - CpuBackendContext::GetFromContext(context)); + const bool is_hybrid = input->type == kTfLiteFloat32; + if (is_hybrid) { + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer; + TF_LITE_ENSURE_OK( + context, + GetTemporarySafe(context, node, kScratchBuffer, &scratch_buffer)); + + OpData* op_data = reinterpret_cast(node->user_data); + TfLiteTensor* row_sums; + TF_LITE_ENSURE_OK(context, + GetTemporarySafe(context, node, kRowSums, &row_sums)); + const int row_sums_size = row_sums->dims->data[0]; + return lstm_eval::EvalHybrid( + input, input_to_input_weights, + /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights, + /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights, + /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights, + /*input_to_output_weights_ledger*/ nullptr, + recurrent_to_input_weights, + /*recurrent_to_input_weights_ledger*/ nullptr, + recurrent_to_forget_weights, + /*recurrent_to_forget_weights_ledger*/ nullptr, + recurrent_to_cell_weights, + /*recurrent_to_cell_weights_ledger*/ nullptr, + recurrent_to_output_weights, + /*recurrent_to_output_weights_ledger*/ nullptr, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_layer_norm_coefficients, + forget_layer_norm_coefficients, cell_layer_norm_coefficients, + output_layer_norm_coefficients, + /*aux_input=*/nullptr, + /*aux_input_to_input_weights=*/nullptr, + /*aux_input_to_forget_weights=*/nullptr, + /*aux_input_to_cell_weights=*/nullptr, + /*aux_input_to_output_weights=*/nullptr, input_gate_bias, + forget_gate_bias, cell_gate_bias, output_gate_bias, + projection_weights, /*projection_weights_ledger*/ nullptr, + projection_bias, &lstm_params, + /*forward_sequence=*/true, time_major, + /*output_offset=*/0, scratch_buffer, + GetTemporary(context, node, kInputScalingFactors), + /*aux_input_sf=*/nullptr, + GetTemporary(context, node, kOutputStateScalingFactors), + GetTemporary(context, node, kProductScalingFactors), + GetTemporary(context, node, kRecoveredCellWeights), + GetTemporary(context, node, kInputQuantized), + /*aux_input_quantized=*/nullptr, + GetTemporary(context, node, kOutputStateQuantized), + GetTemporary(context, node, kCellStateQuantized), output_state, + cell_state, GetTemporary(context, node, kAccumScratch), output, + GetTemporary(context, node, kInputZeroPoints), + /*aux_input_zp=*/nullptr, + GetTemporary(context, node, kOutputStateZeroPoints), row_sums, + row_sums_size, &op_data->compute_row_sums, + CpuBackendContext::GetFromContext(context)); + } else { + TfLiteTensor* scratch0; + TF_LITE_ENSURE_OK(context, + GetTemporarySafe(context, node, 0, &scratch0)); + TfLiteTensor* scratch1; + TF_LITE_ENSURE_OK(context, + GetTemporarySafe(context, node, 1, &scratch1)); + TfLiteTensor* scratch2; + TF_LITE_ENSURE_OK(context, + GetTemporarySafe(context, node, 2, &scratch2)); + TfLiteTensor* scratch3; + TF_LITE_ENSURE_OK(context, + GetTemporarySafe(context, node, 3, &scratch3)); + TfLiteTensor* scratch4; + TF_LITE_ENSURE_OK(context, + GetTemporarySafe(context, node, 4, &scratch4)); + TfLiteTensor* scratch5; + TF_LITE_ENSURE_OK(context, + GetTemporarySafe(context, node, 5, &scratch5)); + return lstm_eval::EvalInteger8x8_16( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_layer_norm_coefficients, + forget_layer_norm_coefficients, cell_layer_norm_coefficients, + output_layer_norm_coefficients, input_gate_bias, forget_gate_bias, + cell_gate_bias, output_gate_bias, projection_weights, + projection_bias, &lstm_params, /*forward_sequence=*/true, + time_major, &op_data->integer_lstm_param, output_state, cell_state, + output, scratch0, scratch1, scratch2, scratch3, scratch4, scratch5, + CpuBackendContext::GetFromContext(context)); + } } default: TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.", diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc index 90a96ca98fe..94ed9f19352 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -2739,6 +2739,611 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } +class UnidirectionalSequenceLSTMIntegerOpModel : public SingleOpModel { + public: + UnidirectionalSequenceLSTMIntegerOpModel( + int n_batch, int n_input, int n_cell, int n_output, int sequence_length, + bool time_major, bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, + bool use_layer_norm, bool use_8x8_8_implementation, + const std::vector>& ranges, + const std::vector>& intermediates, + bool asymmetric_quantize_inputs = false) + : n_input_(n_input), n_output_(n_output) { + input_ = AddInput({TensorType_INT8, + {sequence_length, n_batch, n_input}, + ranges[0].first, + ranges[0].second}); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput({TensorType_INT8, + {n_cell, n_input}, + ranges[1].first, + ranges[1].second}); + } + input_to_forget_weights_ = AddInput({TensorType_INT8, + {n_cell, n_input}, + ranges[2].first, + ranges[2].second}); + input_to_cell_weights_ = AddInput({TensorType_INT8, + {n_cell, n_input}, + ranges[3].first, + ranges[3].second}); + input_to_output_weights_ = AddInput({TensorType_INT8, + {n_cell, n_input}, + ranges[4].first, + ranges[4].second}); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput({TensorType_INT8, + {n_cell, n_output}, + ranges[5].first, + ranges[5].second}); + } + recurrent_to_forget_weights_ = AddInput({TensorType_INT8, + {n_cell, n_output}, + ranges[6].first, + ranges[6].second}); + recurrent_to_cell_weights_ = AddInput({TensorType_INT8, + {n_cell, n_output}, + ranges[7].first, + ranges[7].second}); + recurrent_to_output_weights_ = AddInput({TensorType_INT8, + {n_cell, n_output}, + ranges[8].first, + ranges[8].second}); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput( + {TensorType_INT16, {n_cell}, ranges[9].first, ranges[9].second}); + } + cell_to_forget_weights_ = AddInput( + {TensorType_INT16, {n_cell}, ranges[10].first, ranges[10].second}); + cell_to_output_weights_ = AddInput( + {TensorType_INT16, {n_cell}, ranges[11].first, ranges[11].second}); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput( + {TensorType_INT32, {n_cell}, ranges[12].first, ranges[12].second}); + } + forget_gate_bias_ = AddInput( + {TensorType_INT32, {n_cell}, ranges[13].first, ranges[13].second}); + cell_gate_bias_ = AddInput( + {TensorType_INT32, {n_cell}, ranges[14].first, ranges[14].second}); + output_gate_bias_ = AddInput( + {TensorType_INT32, {n_cell}, ranges[15].first, ranges[15].second}); + + if (use_projection_weights) { + projection_weights_ = AddInput({TensorType_INT8, + {n_output, n_cell}, + ranges[16].first, + ranges[16].second}); + } else { + projection_weights_ = AddNullInput(); + } + if (use_projection_bias) { + CHECK(use_projection_weights); + projection_bias_ = AddInput( + {TensorType_INT32, {n_output}, ranges[17].first, ranges[17].second}); + } else { + projection_bias_ = AddNullInput(); + } + + // Adding the 2 state tensors. + AddVariableInput({TensorType_INT16, + {n_batch, n_output}, + ranges[18].first, + ranges[18].second}); + AddVariableInput({TensorType_INT16, + {n_batch, n_cell}, + ranges[19].first, + ranges[19].second}); + + // Layer norm weights. + if (use_layer_norm) { + if (use_cifg) { + input_layer_norm_coefficients_ = AddNullInput(); + } else { + input_layer_norm_coefficients_ = AddInput( + {TensorType_INT16, {n_cell}, ranges[20].first, ranges[20].second}); + } + forget_layer_norm_coefficients_ = AddInput( + {TensorType_INT16, {n_cell}, ranges[21].first, ranges[21].second}); + cell_layer_norm_coefficients_ = AddInput( + {TensorType_INT16, {n_cell}, ranges[22].first, ranges[22].second}); + output_layer_norm_coefficients_ = AddInput( + {TensorType_INT16, {n_cell}, ranges[23].first, ranges[23].second}); + } + + // use_8x8_8_implementation is not supported yet. + CHECK(!use_8x8_8_implementation); + EXPECT_EQ(intermediates.size(), 5); + + for (int i = 0; i < intermediates.size(); ++i) { + AddIntermediate(TensorType_INT16, {intermediates[i].first}, + {intermediates[i].second}); + } + + output_ = AddOutput({TensorType_INT8, + {n_batch, n_output}, + ranges[24].first, + ranges[24].second}); + + // TODO(b/161825581): Add tests where cell_clip and/or proj_clip is not the + // default 0. + SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOptions_UnidirectionalSequenceLSTMOptions, + CreateUnidirectionalSequenceLSTMOptions( + builder_, ActivationFunctionType_TANH, /*cell_clip=*/0.0f, + /*proj_clip=*/0.0f, time_major, asymmetric_quantize_inputs) + .Union()); + + BuildInterpreter(/*input_shapes=*/{}, /*num_threads=*/-1, + /*allow_fp32_relax_to_fp16=*/false, + /*apply_delegate=*/true, /*allocate_and_delegate=*/false); + } + + void PerformAllocateAndDelegate() { AllocateAndDelegate(true); } + + void SetInputToInputWeights(const std::vector& f) { + QuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(const std::vector& f) { + QuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(const std::vector& f) { + QuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(const std::vector& f) { + QuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(const std::vector& f) { + QuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(const std::vector& f) { + QuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(const std::vector& f) { + QuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(const std::vector& f) { + QuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(const std::vector& f) { + QuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(const std::vector& f) { + QuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(const std::vector& f) { + QuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetInputLayerNormCoefficients(const std::vector& f) { + QuantizeAndPopulate(input_layer_norm_coefficients_, f); + } + + void SetForgetLayerNormCoefficients(const std::vector& f) { + QuantizeAndPopulate(forget_layer_norm_coefficients_, f); + } + + void SetCellLayerNormCoefficients(const std::vector& f) { + QuantizeAndPopulate(cell_layer_norm_coefficients_, f); + } + + void SetOutputLayerNormCoefficients(const std::vector& f) { + QuantizeAndPopulate(output_layer_norm_coefficients_, f); + } + + void SetInputGateBias(const std::vector& f) { + QuantizeAndPopulate(input_gate_bias_, f); + } + + void SetForgetGateBias(const std::vector& f) { + QuantizeAndPopulate(forget_gate_bias_, f); + } + + void SetCellBias(const std::vector& f) { + QuantizeAndPopulate(cell_gate_bias_, f); + } + + void SetOutputGateBias(const std::vector& f) { + QuantizeAndPopulate(output_gate_bias_, f); + } + + void SetProjectionWeights(const std::vector& f) { + QuantizeAndPopulate(projection_weights_, f); + } + + void SetProjectionBias(const std::vector& f) { + QuantizeAndPopulate(projection_bias_, f); + } + + void SetInput(const std::vector& f) { + QuantizeAndPopulate(input_, f); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + + protected: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_layer_norm_coefficients_; + int forget_layer_norm_coefficients_; + int cell_layer_norm_coefficients_; + int output_layer_norm_coefficients_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_gate_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_; + + int n_input_; + int n_output_; +}; + +TEST(IntegerUnidirectionalSequenceLstmOpTest, + NoCifg_NoPeephole_Projection_LayerNorm) { + // Hyper parameters. + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const int sequence_length = 3; + + // Model related weights. + const std::vector input_to_input_weights = { + 0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5, + -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1}; + + const std::vector input_to_forget_weights = { + -0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8, + -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + + const std::vector input_to_cell_weights = { + -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6, + 0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + + const std::vector input_to_output_weights = { + -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2, + 0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + const std::vector input_gate_bias = {0.03, 0.15, 0.22, 0.38}; + + const std::vector forget_gate_bias = {0.1, -0.3, -0.2, 0.1}; + + const std::vector cell_gate_bias = {-0.05, 0.72, 0.25, 0.08}; + + const std::vector output_gate_bias = {0.05, -0.01, 0.2, 0.1}; + + const std::vector recurrent_to_input_weights = { + -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6}; + + const std::vector recurrent_to_cell_weights = { + -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + + const std::vector recurrent_to_forget_weights = { + -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + + const std::vector recurrent_to_output_weights = { + 0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + const std::vector input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5}; + const std::vector forget_layer_norm_coefficients = {0.2, 0.2, 0.4, + 0.3}; + const std::vector cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8}; + const std::vector output_layer_norm_coefficients = {0.6, 0.2, 0.2, + 0.5}; + + const std::vector projection_weights = { + -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + // Input ranges. + const std::vector> ranges = { + {-1.0, 127.0 / 128}, // input tensor + {-1.0, 1.0}, // input_to_input_weight tensor + {-1.0, 1.0}, // input_to_forget_weight tensor + {-1.0, 1.0}, // input_to_cell_weight tensor + {-1.0, 1.0}, // input_to_output_weight tensor + + {-1.0, 1.0}, // recurrent_to_input_weight tensor + {-1.0, 1.0}, // recurrent_to_forget_weight tensor + {-1.0, 1.0}, // recurrent_to_cell_weight tensor + {-1.0, 1.0}, // recurrent_to_output_weight tensor + + {-1, 1}, // cell_to_input_weight tensor + {-1, 1}, // cell_to_forget_weight tensor + {-1, 1}, // cell_to_output_weight tensor + + {-100, 100}, // input_gate_bias tensor + {-100, 100}, // forget_gate_bias tensor + {-100, 100}, // cell_gate_bias tensor + {-100, 100}, // output_gate_bias tensor + + {-0.5, 0.5}, // projection_weight tensor + {-1, 1}, // projection_bias tensor + + {-1.0, 32767.0 / 32768}, // output_state tensor + {-1, 1}, // cell_state tensor + + {-1.00001, 1.0}, // input_layer_norm_coefficient tensor + {-1.00001, 1.0}, // forget_layer_norm_coefficient tensor + {-1.00001, 1.0}, // cell_layer_norm_coefficient tensor + {-1.00001, 1.0}, // output_layer_norm_coefficient tensor + // Output scale is the same as output_state scale and only output_state + // scale is used in the op, so this is only provided for clarity. + {-1.0, 32767.0 / 32768}, // output tensor. + }; + + // The scale and zero point of intermediate tensors. + std::vector> intermediates = { + {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}}; + + // Create model. + UnidirectionalSequenceLSTMIntegerOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*time_major=*/true, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*use_layer_norm=*/true, + /*use_8x8_8_implementation=*/false, ranges, intermediates); + // Do allocate. + lstm.PerformAllocateAndDelegate(); + + // Set weights. + lstm.SetInputToInputWeights(input_to_input_weights); + lstm.SetInputToCellWeights(input_to_cell_weights); + lstm.SetInputToForgetWeights(input_to_forget_weights); + lstm.SetInputToOutputWeights(input_to_output_weights); + + lstm.SetInputGateBias(input_gate_bias); + lstm.SetCellBias(cell_gate_bias); + lstm.SetForgetGateBias(forget_gate_bias); + lstm.SetOutputGateBias(output_gate_bias); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights); + + lstm.SetProjectionWeights(projection_weights); + + lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients); + lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients); + lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients); + lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients); + + // Model inputs. sequence -batch - input + const std::vector lstm_input = { + 0.7, 0.8, 0.1, 0.2, 0.3, // + 0.8, 0.1, 0.2, 0.4, 0.5, // + 0.2, 0.7, 0.7, 0.1, 0.7, // + 0.3, 0.2, 0.9, 0.8, 0.1, // + 0.7, 0.8, 0.1, 0.2, 0.3, // + 0.3, 0.2, 0.9, 0.8, 0.1, // + }; + + // Expected outputs, n_batch * sequence_length * n_output + const std::vector expected_output = { + 127, 127, -108, -67, 127, 127, -128, 127, 127, + -128, 127, 127, 127, 127, 127, -128, 127, 127, + }; + + // Invoke and verify the result. + lstm.SetInput(lstm_input); + lstm.Invoke(); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output)); +} + +TEST(IntegerUnidirectionalSequenceLstmOpTest, + NoCifg_Peephole_Projection_LayerNorm) { + // Hyper parameters. + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const int sequence_length = 3; + + // Model related weights. + const std::vector input_to_input_weights = { + 0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5, + -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1}; + + const std::vector input_to_forget_weights = { + -0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8, + -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + + const std::vector input_to_cell_weights = { + -0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6, + 0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + + const std::vector input_to_output_weights = { + -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2, + 0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + const std::vector input_gate_bias = {0.03, 0.15, 0.22, 0.38}; + + const std::vector forget_gate_bias = {0.1, -0.3, -0.2, 0.1}; + + const std::vector cell_gate_bias = {-0.05, 0.72, 0.25, 0.08}; + + const std::vector output_gate_bias = {0.05, -0.01, 0.2, 0.1}; + + const std::vector recurrent_to_input_weights = { + -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6}; + + const std::vector recurrent_to_cell_weights = { + -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + + const std::vector recurrent_to_forget_weights = { + -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + + const std::vector recurrent_to_output_weights = { + 0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + const std::vector cell_to_input_weights = {0.3, -0.1, 0.1, -0.2}; + + const std::vector cell_to_forget_weights = {0.2, -0.1, 0.1, -0.2}; + + const std::vector cell_to_output_weights = {0.3, -0.1, 0.1, -0.3}; + + const std::vector input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5}; + const std::vector forget_layer_norm_coefficients = {0.2, 0.2, 0.4, + 0.3}; + const std::vector cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8}; + const std::vector output_layer_norm_coefficients = {0.6, 0.2, 0.2, + 0.5}; + + const std::vector projection_weights = { + -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + // Input ranges. + const std::vector> ranges = { + {-1.0, 127.0 / 128}, // input tensor + {-1.0, 1.0}, // input_to_input_weight tensor + {-1.0, 1.0}, // input_to_forget_weight tensor + {-1.0, 1.0}, // input_to_cell_weight tensor + {-1.0, 1.0}, // input_to_output_weight tensor + + {-1.0, 1.0}, // recurrent_to_input_weight tensor + {-0.9, 0.9}, // recurrent_to_forget_weight tensor + {-1.0, 1.0}, // recurrent_to_cell_weight tensor + {-1.0, 1.0}, // recurrent_to_output_weight tensor + + {-0.3, 0.3}, // cell_to_input_weight tensor + {-0.3, 0.3}, // cell_to_forget_weight tensor + {-0.3, 0.3}, // cell_to_output_weight tensor + + {-100, 100}, // input_gate_bias tensor + {-100, 80}, // forget_gate_bias tensor + {-100, 100}, // cell_gate_bias tensor + {-100, 100}, // output_gate_bias tensor + + {-0.5, 0.5}, // projection_weight tensor + {-1, 1}, // projection_bias tensor + + {-1.0, 32767.0 / 32768}, // output_state tensor + {-1, 1}, // cell_state tensor + + {-0.5, 0.5}, // input_layer_norm_coefficient tensor + {-0.5, 0.5}, // forget_layer_norm_coefficient tensor + {-1.0, 1.0}, // cell_layer_norm_coefficient tensor + {-1.0, 1.0}, // output_layer_norm_coefficient tensor + // Output scale is the same as output_state scale and only output_state + // scale is used in the op, so this is only provided for clarity. + {-1.0, 32767.0 / 32768}, // output tensor. + }; + + // The scale and zero point of intermediate tensors. + std::vector> intermediates = { + {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}}; + + // Create model. + UnidirectionalSequenceLSTMIntegerOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*time_major=*/true, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*use_layer_norm=*/true, + /*use_8x8_8_implementation=*/false, ranges, intermediates); + + // Do allocate. + lstm.PerformAllocateAndDelegate(); + + // Set weights. + lstm.SetInputToInputWeights(input_to_input_weights); + lstm.SetInputToCellWeights(input_to_cell_weights); + lstm.SetInputToForgetWeights(input_to_forget_weights); + lstm.SetInputToOutputWeights(input_to_output_weights); + + lstm.SetInputGateBias(input_gate_bias); + lstm.SetCellBias(cell_gate_bias); + lstm.SetForgetGateBias(forget_gate_bias); + lstm.SetOutputGateBias(output_gate_bias); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights); + + lstm.SetCellToInputWeights(cell_to_input_weights); + lstm.SetCellToForgetWeights(cell_to_forget_weights); + lstm.SetCellToOutputWeights(cell_to_output_weights); + + lstm.SetProjectionWeights(projection_weights); + + lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients); + lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients); + lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients); + lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients); + + // Model inputs. sequence -batch - input + const std::vector lstm_input = { + 0.7, 0.8, 0.1, 0.2, 0.3, // + 0.8, 0.1, 0.2, 0.4, 0.5, // + 0.2, 0.7, 0.7, 0.1, 0.7, // + 0.3, 0.2, 0.9, 0.8, 0.1, // + 0.7, 0.8, 0.1, 0.2, 0.3, // + 0.3, 0.2, 0.9, 0.8, 0.1, // + }; + + // Expected outputs, n_batch * sequence_length * n_output + const std::vector expected_output = { + 127, 127, -16, -21, 127, 127, 23, 127, 127, + -128, 127, 127, 127, 127, 127, -128, 127, 127, + }; + + // Invoke and verify the result. + lstm.SetInput(lstm_input); + lstm.Invoke(); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output)); +} + #define QUANTIZE_PARAMETER_TEST(test) \ INSTANTIATE_TEST_SUITE_P(test, test, ::testing::ValuesIn({false, true})); diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD index bd8f39bb925..73cd4cc3f0c 100644 --- a/tensorflow/lite/micro/BUILD +++ b/tensorflow/lite/micro/BUILD @@ -81,6 +81,7 @@ cc_library( deps = [ ":micro_utils", ":op_resolvers", + "//tensorflow/lite:type_to_tflitetype", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", "//tensorflow/lite/kernels:kernel_util", diff --git a/tensorflow/lite/micro/benchmarks/Makefile.inc b/tensorflow/lite/micro/benchmarks/Makefile.inc index b50a51758c0..a47bc2e723a 100644 --- a/tensorflow/lite/micro/benchmarks/Makefile.inc +++ b/tensorflow/lite/micro/benchmarks/Makefile.inc @@ -27,15 +27,9 @@ tensorflow/lite/micro/examples/person_detection_experimental/person_detect_model $(eval $(call microlite_test,keyword_benchmark,\ $(KEYWORD_BENCHMARK_SRCS),$(KEYWORD_BENCHMARK_HDRS))) -# We can not build the person detection benchmarks for STM32F4 due to -# insufficient flash (see issue #43743 for more details). Currently disabling -# this benchmark for STM32F4 but we can revisit this later. -ifneq ($(TARGET), stm32f4) - $(eval $(call microlite_test,person_detection_benchmark,\ $(PERSON_DETECTION_BENCHMARK_SRCS),$(PERSON_DETECTION_BENCHMARK_HDRS))) $(eval $(call microlite_test,person_detection_experimental_benchmark,\ $(PERSON_DETECTION_EXPERIMENTAL_BENCHMARK_SRCS),$(PERSON_DETECTION_EXPERIMENTAL_BENCHMARK_HDRS))) -endif diff --git a/tensorflow/lite/micro/cortex_m_gcc_generic/README.md b/tensorflow/lite/micro/cortex_m_gcc_generic/README.md deleted file mode 100644 index 4d5f85b24ea..00000000000 --- a/tensorflow/lite/micro/cortex_m_gcc_generic/README.md +++ /dev/null @@ -1,12 +0,0 @@ -# Generic Cortex-Mx customizations - -The customization requires a definition where the debug log goes to. The purpose -of the generic Cortex-Mx target is to generate a TFLM library file for use in -application projects outside of this repo. As the chip HAL and the board -specific layer are only defined in the application project, the TFLM library -cannot write the debug log anywhere. Instead, we allow the application layer to -register a callback function for writing the TFLM kernel debug log. - -# Usage - -See debug_log_callback.h diff --git a/tensorflow/lite/micro/cortex_m_generic/README.md b/tensorflow/lite/micro/cortex_m_generic/README.md new file mode 100644 index 00000000000..69e65944d4f --- /dev/null +++ b/tensorflow/lite/micro/cortex_m_generic/README.md @@ -0,0 +1,65 @@ + + +# Generic Cortex-Mx customizations + +The customization requires a definition where the debug log goes to. The purpose +of the generic Cortex-Mx target is to generate a TFLM library file for use in +application projects outside of this repo. As the chip HAL and the board +specific layer are only defined in the application project, the TFLM library +cannot write the debug log anywhere. Instead, we allow the application layer to +register a callback function for writing the TFLM kernel debug log. + +# Usage + +See debug_log_callback.h + +# How to build + +Required parameters: + + - TARGET: cortex_m_generic + - TARGET_ARCH: cortex-mXX (For all options see: tensorflow/lite/micro/tools/make/targets/cortex_m_generic_makefile.inc) + +Optional parameters: + + - TOOLCHAIN: gcc (default) or armmclang + - For Cortex-M55, ARM Compiler 6.14 or later is required. + +Some examples: + +Building with arm-gcc + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=cortex_m_generic TARGET_ARCH=cortex-m7 microlite +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=cortex_m_generic TARGET_ARCH=cortex-m7 TAGS=cmsis-nn microlite + +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=cortex_m_generic TARGET_ARCH=cortex-m4 TAGS=cmsis-nn microlite +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=cortex_m_generic TARGET_ARCH=cortex-m4+fp TAGS=cmsis-nn microlite +``` + +Building with armclang + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TOOLCHAIN=armclang TARGET=cortex_m_generic TARGET_ARCH=cortex-m55 microlite +make -f tensorflow/lite/micro/tools/make/Makefile TOOLCHAIN=armclang TARGET=cortex_m_generic TARGET_ARCH=cortex-m55 TAGS=cmsis-nn microlite +make -f tensorflow/lite/micro/tools/make/Makefile TOOLCHAIN=armclang TARGET=cortex_m_generic TARGET_ARCH=cortex-m55+nofp TAGS=cmsis-nn microlite +``` + +The Tensorflow Lite Micro makefiles download a specific version of the arm-gcc +compiler to tensorflow/lite/micro/tools/make/downloads/gcc_embedded. + +If desired, a different version can be used by providing `TARGET_TOOLCHAIN_ROOT` +option to the Makefile: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=cortex_m_generic TARGET_ARCH=cortex-m4+fp TARGET_TOOLCHAIN_ROOT=/path/to/arm-gcc/ microlite +``` + +Similarly, `TAGS=cmsis-nn` downloads a specific version of CMSIS to +tensorflow/lite/micro/tools/make/downloads/cmsis. While this is the only version +that is regularly tested, you can use your own version of CMSIS as well by +providing `CMSIS_PATH` to the Makefile: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=cortex_m_generic TARGET_ARCH=cortex-m4+fp TAGS=cmsis-nn CMSIS_PATH=/path/to/own/cmsis microlite +``` diff --git a/tensorflow/lite/micro/cortex_m_gcc_generic/debug_log.cc b/tensorflow/lite/micro/cortex_m_generic/debug_log.cc similarity index 91% rename from tensorflow/lite/micro/cortex_m_gcc_generic/debug_log.cc rename to tensorflow/lite/micro/cortex_m_generic/debug_log.cc index fce512e199b..bc79d439170 100644 --- a/tensorflow/lite/micro/cortex_m_gcc_generic/debug_log.cc +++ b/tensorflow/lite/micro/cortex_m_generic/debug_log.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // Implementation for the DebugLog() function that prints to the debug logger on -// an generic cortex-m device. +// an generic Cortex-M device. #ifdef __cplusplus extern "C" { @@ -22,7 +22,7 @@ extern "C" { #include "tensorflow/lite/micro/debug_log.h" -#include "tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h" +#include "tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h" static DebugLogCallback debug_log_callback = nullptr; diff --git a/tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h b/tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h similarity index 83% rename from tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h rename to tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h index d462c8db368..c1afd19a578 100644 --- a/tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h +++ b/tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_ -#define TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_ +#ifndef TENSORFLOW_LITE_MICRO_CORTEX_M_GENERIC_DEBUG_LOG_CALLBACK_H_ +#define TENSORFLOW_LITE_MICRO_CORTEX_M_GENERIC_DEBUG_LOG_CALLBACK_H_ // The application layer must implement and register a callback before calling // the network in a way similar to @@ -46,4 +46,4 @@ void RegisterDebugLogCallback(DebugLogCallback callback); } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_ +#endif // TENSORFLOW_LITE_MICRO_CORTEX_M_GENERIC_DEBUG_LOG_CALLBACK_H_ diff --git a/tensorflow/lite/micro/debug_log.cc b/tensorflow/lite/micro/debug_log.cc index 7ef582bd376..46ca253a6d5 100644 --- a/tensorflow/lite/micro/debug_log.cc +++ b/tensorflow/lite/micro/debug_log.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,6 +36,15 @@ limitations under the License. #include "tensorflow/lite/micro/debug_log.h" +#ifndef TF_LITE_STRIP_ERROR_STRINGS #include +#endif -extern "C" void DebugLog(const char* s) { fprintf(stderr, "%s", s); } +extern "C" void DebugLog(const char* s) { +#ifndef TF_LITE_STRIP_ERROR_STRINGS + // Reusing TF_LITE_STRIP_ERROR_STRINGS to disable DebugLog completely to get + // maximum reduction in binary size. This is because we have DebugLog calls + // via TF_LITE_CHECK that are not stubbed out by TF_LITE_REPORT_ERROR. + fprintf(stderr, "%s", s); +#endif +} diff --git a/tensorflow/lite/micro/docs/memory_management.md b/tensorflow/lite/micro/docs/memory_management.md index a936cb6d7c3..36b7228fe08 100644 --- a/tensorflow/lite/micro/docs/memory_management.md +++ b/tensorflow/lite/micro/docs/memory_management.md @@ -1,3 +1,5 @@ + + | TAIL | - -## | | | | - -* Lowest Address Highest Address * ``` +``` +-------------------------------------------------------------------------------- +| | | | +| HEAD |<-- TEMPORARY -->| TAIL | +| | | | +-------------------------------------------------------------------------------- +* Lowest Address Highest Address * +``` ### Head Section @@ -129,20 +132,18 @@ interpreter.PrintAllocations(); The output of this call will look something similar to this (output from the [memory_arena_threshold_test](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/memory_arena_threshold_test.cc#L205)): -`sh [RecordingMicroAllocator] Arena allocation total 9568 bytes + +```bash +[RecordingMicroAllocator] Arena allocation total 9568 bytes [RecordingMicroAllocator] Arena allocation head 7744 bytes [RecordingMicroAllocator] Arena allocation tail 1824 bytes -[RecordingMicroAllocator] 'TfLiteEvalTensor data' used 360 bytes with alignment -overhead (requested 360 bytes for 15 allocations) [RecordingMicroAllocator] -'Persistent TfLiteTensor data' used 0 bytes with alignment overhead (requested 0 -bytes for 0 tensors) [RecordingMicroAllocator] 'Persistent TfLiteTensor -quantization data' used 0 bytes with alignment overhead (requested 0 bytes for 0 -allocations) [RecordingMicroAllocator] 'TfLiteTensor variable buffer data' used -0 bytes with alignment overhead (requested 0 bytes for 0 allocations) -[RecordingMicroAllocator] 'NodeAndRegistration struct' used 392 bytes with -alignment overhead (requested 392 bytes for 7 NodeAndRegistration structs) -[RecordingMicroAllocator] 'Operator runtime data' used 136 bytes with alignment -overhead (requested 136 bytes for 5 OpData structs)` +[RecordingMicroAllocator] 'TfLiteEvalTensor data' used 360 bytes with alignment overhead (requested 360 bytes for 15 allocations) +[RecordingMicroAllocator] 'Persistent TfLiteTensor data' used 0 bytes with alignment overhead (requested 0 bytes for 0 tensors) +[RecordingMicroAllocator] 'Persistent TfLiteTensor quantization data' used 0 bytes with alignment overhead (requested 0 bytes for 0 allocations) +[RecordingMicroAllocator] 'TfLiteTensor variable buffer data' used 0 bytes with alignment overhead (requested 0 bytes for 0 allocations) +[RecordingMicroAllocator] 'NodeAndRegistration struct' used 392 bytes with alignment overhead (requested 392 bytes for 7 NodeAndRegistration structs) +[RecordingMicroAllocator] 'Operator runtime data' used 136 bytes with alignment overhead (requested 136 bytes for 5 OpData structs) +``` ### Allocation Section Details diff --git a/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/prj.conf b/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/prj.conf index e36145c332a..f4d8a9fed1d 100644 --- a/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/prj.conf +++ b/tensorflow/lite/micro/examples/hello_world/zephyr_riscv/prj.conf @@ -14,3 +14,4 @@ # ============================================================================== CONFIG_CPLUSPLUS=y CONFIG_NEWLIB_LIBC=y +CONFIG_NETWORKING=n diff --git a/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc b/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc index 188298524dd..c2a30280f62 100644 --- a/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc +++ b/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc @@ -64,7 +64,7 @@ int main(int argc, char** argv) { constexpr int tensor_arena_size = 50 * 1024; uint8_t tensor_arena[tensor_arena_size]; - tflite::MicroInterpreter interpreter(model, resolver, tensor_arena, + tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena, tensor_arena_size, error_reporter); interpreter.AllocateTensors(); diff --git a/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc b/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc index 22ea96f6d49..ad0c5e6268e 100644 --- a/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc +++ b/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc @@ -27,6 +27,7 @@ limitations under the License. // Create an area of memory to use for input, output, and intermediate arrays. constexpr int tensor_arena_size = 93 * 1024; +__attribute__((section(".bss.NoInit"), aligned(16))) uint8_t tensor_arena[tensor_arena_size]; TF_LITE_MICRO_TESTS_BEGIN diff --git a/tensorflow/lite/micro/hexagon/micro_time.cc b/tensorflow/lite/micro/hexagon/micro_time.cc new file mode 100644 index 00000000000..9baf77b5653 --- /dev/null +++ b/tensorflow/lite/micro/hexagon/micro_time.cc @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Hexagon timer implementation. +// To include this with make, add TARGET=hexagon. +#include "tensorflow/lite/micro/micro_time.h" + +#include + +namespace tflite { + +int32_t ticks_per_second() { return CLOCKS_PER_SEC; } + +int32_t GetCurrentTimeTicks() { return clock(); } + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 6eaf3549b32..e9d3faaf027 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -216,7 +216,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -231,7 +230,6 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels/internal:tensor", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -243,10 +241,8 @@ tflite_micro_cc_test( "fully_connected_test.cc", ], deps = [ - ":fully_connected", ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", @@ -290,7 +286,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -305,7 +300,6 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:micro_utils", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -319,7 +313,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -389,7 +382,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -403,7 +395,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -431,7 +422,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -488,7 +478,6 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:debug_log", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -503,7 +492,6 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:debug_log", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -582,8 +570,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -597,8 +583,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -644,9 +628,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels/internal:tensor", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -674,7 +656,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -688,7 +669,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", @@ -758,8 +738,6 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", - "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], @@ -771,7 +749,18 @@ tflite_micro_cc_test( deps = [ ":kernel_runner", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + +cc_test( + name = "shape_test", + srcs = ["shape_test.cc"], + deps = [ + ":kernel_runner", + "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", diff --git a/tensorflow/lite/micro/kernels/activations.cc b/tensorflow/lite/micro/kernels/activations.cc index b6feb786a95..a92d5c73820 100644 --- a/tensorflow/lite/micro/kernels/activations.cc +++ b/tensorflow/lite/micro/kernels/activations.cc @@ -205,12 +205,12 @@ TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, input != nullptr); if (input->type == kTfLiteInt8) { - data->six_int8 = FloatToAsymmetricQuantizedInt8(6.0f, input->params.scale, - input->params.zero_point); + data->six_int8 = FloatToQuantizedType(6.0f, input->params.scale, + input->params.zero_point); data->zero_int8 = input->params.zero_point; } else if (input->type == kTfLiteUInt8) { - data->six_uint8 = FloatToAsymmetricQuantizedUInt8(6.0f, input->params.scale, - input->params.zero_point); + data->six_uint8 = FloatToQuantizedType(6.0f, input->params.scale, + input->params.zero_point); data->zero_uint8 = input->params.zero_point; } diff --git a/tensorflow/lite/micro/kernels/activations_test.cc b/tensorflow/lite/micro/kernels/activations_test.cc index edbe717bedb..3a51472f9bb 100644 --- a/tensorflow/lite/micro/kernels/activations_test.cc +++ b/tensorflow/lite/micro/kernels/activations_test.cc @@ -35,8 +35,8 @@ void TestReluFloat(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {1, 0}; @@ -68,8 +68,8 @@ void TestRelu6Float(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {1, 0}; @@ -123,8 +123,8 @@ void TestReluUint8(const int* input_dims_data, const float* input_data, TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - AsymmetricQuantize(golden, golden_quantized, output_elements_count, - output_scale, output_zero_point); + Quantize(golden, golden_quantized, output_elements_count, output_scale, + output_zero_point); for (int i = 0; i < output_elements_count; ++i) { TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]); @@ -164,8 +164,8 @@ void TestRelu6Uint8(const int* input_dims_data, const float* input_data, TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - AsymmetricQuantize(golden, golden_quantized, output_elements_count, - output_scale, output_zero_point); + Quantize(golden, golden_quantized, output_elements_count, output_scale, + output_zero_point); for (int i = 0; i < output_elements_count; ++i) { TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]); @@ -204,8 +204,8 @@ void TestReluInt8(const int* input_dims_data, const float* input_data, TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - AsymmetricQuantize(golden, golden_quantized, output_elements_count, - output_scale, output_zero_point); + Quantize(golden, golden_quantized, output_elements_count, output_scale, + output_zero_point); for (int i = 0; i < output_elements_count; ++i) { TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]); @@ -244,8 +244,8 @@ void TestRelu6Int8(const int* input_dims_data, const float* input_data, TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - AsymmetricQuantize(golden, golden_quantized, output_elements_count, - output_scale, output_zero_point); + Quantize(golden, golden_quantized, output_elements_count, output_scale, + output_zero_point); for (int i = 0; i < output_elements_count; ++i) { TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]); diff --git a/tensorflow/lite/micro/kernels/add_test.cc b/tensorflow/lite/micro/kernels/add_test.cc index 241dda7d090..a11b73c3290 100644 --- a/tensorflow/lite/micro/kernels/add_test.cc +++ b/tensorflow/lite/micro/kernels/add_test.cc @@ -100,9 +100,9 @@ void TestAddFloat(const int* input1_dims_data, const float* input1_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input1_data, input1_dims), - CreateFloatTensor(input2_data, input2_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims), }; ValidateAddGoldens(tensors, tensors_size, expected_output, output_data, @@ -136,9 +136,8 @@ void TestAddQuantized(const int* input1_dims_data, const float* input1_data, tflite::testing::CreateQuantizedTensor(output_data, output_dims, output_scale, output_zero_point), }; - tflite::AsymmetricQuantize(golden, golden_quantized, - ElementCount(*output_dims), output_scale, - output_zero_point); + tflite::Quantize(golden, golden_quantized, ElementCount(*output_dims), + output_scale, output_zero_point); ValidateAddGoldens(tensors, tensors_size, golden_quantized, output_data, ElementCount(*output_dims), activation); diff --git a/tensorflow/lite/micro/kernels/arc_mli/conv.cc b/tensorflow/lite/micro/kernels/arc_mli/conv.cc index 55ef2650bef..4522421fa56 100644 --- a/tensorflow/lite/micro/kernels/arc_mli/conv.cc +++ b/tensorflow/lite/micro/kernels/arc_mli/conv.cc @@ -66,6 +66,7 @@ struct OpData { int32_t output_activation_max; }; +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) inline PaddingType RuntimePaddingType(TfLitePadding padding) { switch (padding) { case TfLitePadding::kTfLitePaddingSame: @@ -77,6 +78,7 @@ inline PaddingType RuntimePaddingType(TfLitePadding padding) { return PaddingType::kNone; } } +#endif bool IsMliApplicable(TfLiteContext* context, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, @@ -194,7 +196,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { data->output_zero_point = output->params.zero_point; return kTfLiteOk; -} // namespace conv +} void EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params, const OpData& data, @@ -259,10 +261,10 @@ TfLiteStatus EvalMliQuantizedPerChannel( mli_weights.el_params.asym.zero_point.pi16 = &filter_zero_point; mli_bias.el_params.asym.zero_point.pi16 = &bias_zero_point; - ConvertToMliTensor(input, &mli_in); - ConvertToMliTensorPerChannel(filter, &mli_weights); - ConvertToMliTensorPerChannel(bias, &mli_bias); - ConvertToMliTensor(output, &mli_out); + ops::micro::ConvertToMliTensor(input, &mli_in); + ops::micro::ConvertToMliTensorPerChannel(filter, &mli_weights); + ops::micro::ConvertToMliTensorPerChannel(bias, &mli_bias); + ops::micro::ConvertToMliTensor(output, &mli_out); if (params->activation == kTfLiteActRelu) { cfg.relu.type = MLI_RELU_GEN; @@ -313,14 +315,16 @@ TfLiteStatus EvalMliQuantizedPerChannel( mli_tensor out_local = mli_out; mli_mov_cfg_t copy_config; mli_mov_cfg_for_copy(©_config); - TF_LITE_ENSURE_STATUS(get_arc_scratch_buffer_for_conv_tensors( + TF_LITE_ENSURE_STATUS(ops::micro::get_arc_scratch_buffer_for_conv_tensors( context, &in_local, &weights_local, &bias_local, &out_local)); - TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_io( + TF_LITE_ENSURE_STATUS(ops::micro::arc_scratch_buffer_calc_slice_size_io( &in_local, &out_local, kernel_height, cfg.stride_height, cfg.padding_top, cfg.padding_bottom, &in_slice_height, &out_slice_height)); - TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_weights( - &weights_local, &bias_local, weight_out_ch_dimension, &slice_channels)); + TF_LITE_ENSURE_STATUS( + ops::micro::arc_scratch_buffer_calc_slice_size_weights( + &weights_local, &bias_local, weight_out_ch_dimension, + &slice_channels)); /* is_local indicates that the tensor is already in local memory, so in that case the original tensor can be used, @@ -330,10 +334,12 @@ TfLiteStatus EvalMliQuantizedPerChannel( const bool w_is_local = weights_local.data == mli_weights.data; const bool b_is_local = bias_local.data == mli_bias.data; - TensorSlicer w_slice(&mli_weights, weight_out_ch_dimension, slice_channels); - TensorSlicer b_slice(&mli_bias, weight_out_ch_dimension, slice_channels); - TensorSlicer out_ch_slice(&mli_out, out_tensor_ch_dimension, slice_channels, - 0, 0, 0, true); + ops::micro::TensorSlicer w_slice(&mli_weights, weight_out_ch_dimension, + slice_channels); + ops::micro::TensorSlicer b_slice(&mli_bias, weight_out_ch_dimension, + slice_channels); + ops::micro::TensorSlicer out_ch_slice(&mli_out, out_tensor_ch_dimension, + slice_channels, 0, 0, 0, true); mli_tensor* w_ptr = w_is_local ? w_slice.Sub() : &weights_local; mli_tensor* b_ptr = b_is_local ? b_slice.Sub() : &bias_local; @@ -352,15 +358,16 @@ TfLiteStatus EvalMliQuantizedPerChannel( dimension. for that the sliceHeight has been calculated. The tensor slicer is configured that it will completely slice the nBatch dimension (0) and slice the height dimension (1) in chunks of 'sliceHeight' */ - TensorSlicer in_slice(&mli_in, height_dimension, in_slice_height, - cfg.padding_top, cfg.padding_bottom, overlap); + ops::micro::TensorSlicer in_slice(&mli_in, height_dimension, + in_slice_height, cfg.padding_top, + cfg.padding_bottom, overlap); /* output tensor is alreade sliced in the output channel dimension. out_ch_slice.Sub() is the tensor for the amount of output channels of this itteration of the weight slice loop. This tensor needs to be further sliced over the batch and height dimension. */ - TensorSlicer out_slice(out_ch_slice.Sub(), height_dimension, - out_slice_height); + ops::micro::TensorSlicer out_slice(out_ch_slice.Sub(), height_dimension, + out_slice_height); /* setup the pointers to the local or remote tensor to make the code * inside the loop easier. */ diff --git a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc index d30a5308708..8fe5d307cdd 100644 --- a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc @@ -242,10 +242,10 @@ TfLiteStatus EvalMliQuantizedPerChannel( mli_weights.el_params.asym.zero_point.pi16 = &filter_zero_point; mli_bias.el_params.asym.zero_point.pi16 = &bias_zero_point; - ConvertToMliTensor(input, &mli_in); - ConvertToMliTensorPerChannel(filter, &mli_weights); - ConvertToMliTensorPerChannel(bias, &mli_bias); - ConvertToMliTensor(output, &mli_out); + ops::micro::ConvertToMliTensor(input, &mli_in); + ops::micro::ConvertToMliTensorPerChannel(filter, &mli_weights); + ops::micro::ConvertToMliTensorPerChannel(bias, &mli_bias); + ops::micro::ConvertToMliTensor(output, &mli_out); if (params->activation == kTfLiteActRelu) { cfg.relu.type = MLI_RELU_GEN; @@ -301,7 +301,7 @@ TfLiteStatus EvalMliQuantizedPerChannel( mli_mov_cfg_t copy_config; mli_mov_cfg_for_copy(©_config); - TF_LITE_ENSURE_STATUS(get_arc_scratch_buffer_for_conv_tensors( + TF_LITE_ENSURE_STATUS(ops::micro::get_arc_scratch_buffer_for_conv_tensors( context, &in_local, &weights_local, &bias_local, &out_local)); /* is_local indicates that the tensor is already in local memory, so in that case the original tensor can be used, @@ -311,10 +311,10 @@ TfLiteStatus EvalMliQuantizedPerChannel( const bool w_is_local = weights_local.data == mli_weights.data; const bool b_is_local = bias_local.data == mli_bias.data; - TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_io( + TF_LITE_ENSURE_STATUS(ops::micro::arc_scratch_buffer_calc_slice_size_io( &in_local, &out_local, kernelHeight, cfg.stride_height, cfg.padding_top, cfg.padding_bottom, &inSliceHeight, &outSliceHeight)); - TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_weights( + TF_LITE_ENSURE_STATUS(ops::micro::arc_scratch_buffer_calc_slice_size_weights( &weights_local, &bias_local, weight_out_ch_dimension, &slice_channels)); /* if input channels is not equal to output channels, a channel multiplier @@ -324,13 +324,14 @@ TfLiteStatus EvalMliQuantizedPerChannel( slice_channels = (slice_channels / in_channels) * in_channels; } - TensorSlicer b_slice(&mli_bias, bias_out_ch_dimension, slice_channels); - TensorSlicer w_slice(&mli_weights, weight_out_ch_dimension, slice_channels, 0, - 0, 0, true); - TensorSlicer out_ch_slice(&mli_out, out_tensor_ch_dimension, slice_channels, - 0, 0, 0, true); - TensorSlicer in_ch_slice(&mli_in, out_tensor_ch_dimension, slice_channels, 0, - 0, 0, true); + ops::micro::TensorSlicer b_slice(&mli_bias, bias_out_ch_dimension, + slice_channels); + ops::micro::TensorSlicer w_slice(&mli_weights, weight_out_ch_dimension, + slice_channels, 0, 0, 0, true); + ops::micro::TensorSlicer out_ch_slice(&mli_out, out_tensor_ch_dimension, + slice_channels, 0, 0, 0, true); + ops::micro::TensorSlicer in_ch_slice(&mli_in, out_tensor_ch_dimension, + slice_channels, 0, 0, 0, true); mli_tensor* w_ptr = w_is_local ? w_slice.Sub() : &weights_local; mli_tensor* b_ptr = b_is_local ? b_slice.Sub() : &bias_local; @@ -355,14 +356,16 @@ TfLiteStatus EvalMliQuantizedPerChannel( the sliceHeight has been calculated. The tensor slicer is configured that it will completely slice the nBatch dimension (0) and slice the height dimension (1) in chunks of 'sliceHeight' */ - TensorSlicer in_slice(in_ch_slice.Sub(), heightDimension, inSliceHeight, - padding_top, padding_bottom, overlap); + ops::micro::TensorSlicer in_slice(in_ch_slice.Sub(), heightDimension, + inSliceHeight, padding_top, + padding_bottom, overlap); /* output tensor is alreade sliced in the output channel dimension. out_ch_slice.Sub() is the tensor for the amount of output channels of this itteration of the weight slice loop. This tensor needs to be further sliced over the batch and height dimension. */ - TensorSlicer out_slice(out_ch_slice.Sub(), heightDimension, outSliceHeight); + ops::micro::TensorSlicer out_slice(out_ch_slice.Sub(), heightDimension, + outSliceHeight); /* setup the pointers to the local or remote tensor to make the code * inside the loop easier. */ diff --git a/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc b/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc index 2d201653efc..ea5c6c6eaf3 100644 --- a/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc @@ -29,9 +29,6 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h" namespace tflite { -namespace ops { -namespace micro { -namespace fully_connected { namespace { struct OpData { @@ -127,10 +124,10 @@ TfLiteStatus EvalMliQuantizedInt8(TfLiteContext* context, TfLiteNode* node, mli_tensor mli_bias = {}; mli_tensor mli_out = {}; - ConvertToMliTensor(input, &mli_in); - ConvertToMliTensor(filter, &mli_weights); - ConvertToMliTensor(bias, &mli_bias); - ConvertToMliTensor(output, &mli_out); + ops::micro::ConvertToMliTensor(input, &mli_in); + ops::micro::ConvertToMliTensor(filter, &mli_weights); + ops::micro::ConvertToMliTensor(bias, &mli_bias); + ops::micro::ConvertToMliTensor(output, &mli_out); /* The input tensor can have more than 2 dimensions. for the compute this doesn't make any difference because all the inputs or a batch entry will @@ -156,9 +153,10 @@ TfLiteStatus EvalMliQuantizedInt8(TfLiteContext* context, TfLiteNode* node, int slice_size = mli_weights.shape[weight_out_dimension]; /* allocate the local buffers, and compute the slice size */ - TF_LITE_ENSURE_STATUS(get_arc_scratch_buffer_for_fully_connect_tensors( - context, &in_local, &weights_local, &bias_local, &out_local)); - TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_weights( + TF_LITE_ENSURE_STATUS( + ops::micro::get_arc_scratch_buffer_for_fully_connect_tensors( + context, &in_local, &weights_local, &bias_local, &out_local)); + TF_LITE_ENSURE_STATUS(ops::micro::arc_scratch_buffer_calc_slice_size_weights( &weights_local, &bias_local, weight_out_dimension, &slice_size)); int max_out_slice_size = out_local.capacity / mli_hlp_tensor_element_size(&out_local); @@ -172,10 +170,11 @@ TfLiteStatus EvalMliQuantizedInt8(TfLiteContext* context, TfLiteNode* node, const bool w_is_local = weights_local.data == mli_weights.data; const bool b_is_local = bias_local.data == mli_bias.data; - TensorSlicer w_slice(&mli_weights, weight_out_dimension, slice_size); - TensorSlicer b_slice(&mli_bias, weight_out_dimension, slice_size); - TensorSlicer out_ch_slice(&mli_out, out_tensor_dimension, slice_size, 0, 0, 0, - true); + ops::micro::TensorSlicer w_slice(&mli_weights, weight_out_dimension, + slice_size); + ops::micro::TensorSlicer b_slice(&mli_bias, weight_out_dimension, slice_size); + ops::micro::TensorSlicer out_ch_slice(&mli_out, out_tensor_dimension, + slice_size, 0, 0, 0, true); mli_tensor* w_ptr = w_is_local ? w_slice.Sub() : &weights_local; mli_tensor* b_ptr = b_is_local ? b_slice.Sub() : &bias_local; @@ -188,15 +187,15 @@ TfLiteStatus EvalMliQuantizedInt8(TfLiteContext* context, TfLiteNode* node, // Slice the input over the batches (one at a time with the size of a // complete input) - TensorSlicer in_slice(&mli_in, input_size_dimension, - mli_in.shape[input_size_dimension]); + ops::micro::TensorSlicer in_slice(&mli_in, input_size_dimension, + mli_in.shape[input_size_dimension]); /* output tensor is alreade sliced in the output size dimension. out_ch_slice.Sub() is the tensor for the amount of output size of this itteration of the weight slice loop. This tensor needs to be further sliced over the batch */ - TensorSlicer out_slice(out_ch_slice.Sub(), out_tensor_dimension, - slice_size); + ops::micro::TensorSlicer out_slice(out_ch_slice.Sub(), out_tensor_dimension, + slice_size); /* setup the pointers to the local or remote tensor to make the code * inside the loop easier. */ @@ -359,19 +358,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace fully_connected - TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/fully_connected::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/fully_connected::Prepare, - /*invoke=*/fully_connected::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.cc b/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.cc index e20eea22a03..905c6fedf9d 100644 --- a/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.cc +++ b/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.cc @@ -25,13 +25,13 @@ TensorSlicer::TensorSlicer(const mli_tensor* full_tensor, int slice_dim, int slice_size, int padding_pre, int padding_post, int overlap, bool interleave_mode) : full_tensor_(full_tensor), + sub_tensor_{}, + sub_cfg_{}, + done_(false), sliceDim_(slice_dim), pad_pre_(padding_pre), pad_post_(padding_post), - overlap_(overlap), - sub_cfg_{}, - sub_tensor_{}, - done_(false) { + overlap_(overlap) { /* In the interleave mode, the slicing happens from the deepest dimension up to the slice_dim for example in an HWC layout this can mode can be used to slice in the C dimenstion. in this mode the data is not contiguous in memory diff --git a/tensorflow/lite/micro/kernels/arg_min_max_test.cc b/tensorflow/lite/micro/kernels/arg_min_max_test.cc index b85d59555f8..e1e87d39be3 100644 --- a/tensorflow/lite/micro/kernels/arg_min_max_test.cc +++ b/tensorflow/lite/micro/kernels/arg_min_max_test.cc @@ -60,9 +60,9 @@ void TestArgMinMaxFloat(const int* input_dims_data, const float* input_values, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_values, input_dims), - CreateInt32Tensor(axis_values, axis_dims), - CreateInt32Tensor(output, output_dims), + CreateTensor(input_values, input_dims), + CreateTensor(axis_values, axis_dims), + CreateTensor(output, output_dims), }; ValidateArgMinMaxGoldens(tensors, tensors_size, goldens, output, @@ -88,8 +88,8 @@ void TestArgMinMaxQuantized(const int* input_dims_data, TfLiteTensor tensors[tensors_size] = { CreateQuantizedTensor(input_values, input_quantized, input_dims, input_scale, input_zero_point), - CreateInt32Tensor(axis_values, axis_dims), - CreateInt32Tensor(output, output_dims), + CreateTensor(axis_values, axis_dims), + CreateTensor(output, output_dims), }; ValidateArgMinMaxGoldens(tensors, tensors_size, goldens, output, diff --git a/tensorflow/lite/micro/kernels/ceil_test.cc b/tensorflow/lite/micro/kernels/ceil_test.cc index a388d285db3..286cbd2f194 100644 --- a/tensorflow/lite/micro/kernels/ceil_test.cc +++ b/tensorflow/lite/micro/kernels/ceil_test.cc @@ -33,8 +33,8 @@ void TestCeil(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {1, 0}; diff --git a/tensorflow/lite/micro/kernels/comparisons_test.cc b/tensorflow/lite/micro/kernels/comparisons_test.cc index 855192baed2..addb08aa4da 100644 --- a/tensorflow/lite/micro/kernels/comparisons_test.cc +++ b/tensorflow/lite/micro/kernels/comparisons_test.cc @@ -61,9 +61,9 @@ void TestComparisonFloat(const TfLiteRegistration& registration, TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input1_data, input1_dims), - CreateFloatTensor(input2_data, input2_dims), - CreateBoolTensor(output_data, output_dims), + CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims), }; TestComparison(registration, tensors, expected_output_data, output_data); @@ -79,9 +79,9 @@ void TestComparisonBool(const TfLiteRegistration& registration, TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); TfLiteTensor tensors[tensors_size] = { - CreateBoolTensor(input1_data, input1_dims), - CreateBoolTensor(input2_data, input2_dims), - CreateBoolTensor(output_data, output_dims), + CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims), }; TestComparison(registration, tensors, expected_output_data, output_data); @@ -97,9 +97,9 @@ void TestComparisonInt(const TfLiteRegistration& registration, TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); TfLiteTensor tensors[tensors_size] = { - CreateInt32Tensor(input1_data, input1_dims), - CreateInt32Tensor(input2_data, input2_dims), - CreateBoolTensor(output_data, output_dims), + CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims), }; TestComparison(registration, tensors, expected_output_data, output_data); @@ -122,7 +122,7 @@ void TestComparisonQuantizedUInt8(const TfLiteRegistration& registration, input1_scale, input1_zero_point), CreateQuantizedTensor(input2_data, input2_quantized, input2_dims, input2_scale, input2_zero_point), - CreateBoolTensor(output_data, output_dims), + CreateTensor(output_data, output_dims), }; TestComparison(registration, tensors, expected_output_data, output_data); @@ -145,7 +145,7 @@ void TestComparisonQuantizedInt8(const TfLiteRegistration& registration, input1_scale, input1_zero_point), CreateQuantizedTensor(input2_data, input2_quantized, input2_dims, input2_scale, input2_zero_point), - CreateBoolTensor(output_data, output_dims), + CreateTensor(output_data, output_dims), }; TestComparison(registration, tensors, expected_output_data, output_data); diff --git a/tensorflow/lite/micro/kernels/concatenation_test.cc b/tensorflow/lite/micro/kernels/concatenation_test.cc index d7ed2213c98..cb7e0bff626 100644 --- a/tensorflow/lite/micro/kernels/concatenation_test.cc +++ b/tensorflow/lite/micro/kernels/concatenation_test.cc @@ -38,10 +38,9 @@ void TestConcatenateTwoInputs(const int* input1_dims_data, constexpr int input_size = 2; constexpr int output_size = 1; constexpr int tensors_size = input_size + output_size; - TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input1_data, input1_dims), - CreateFloatTensor(input2_data, input2_dims), - CreateFloatTensor(output_data, output_dims)}; + TfLiteTensor tensors[tensors_size] = {CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims)}; int inputs_array_data[] = {2, 0, 1}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); diff --git a/tensorflow/lite/micro/kernels/conv.cc b/tensorflow/lite/micro/kernels/conv.cc index 9b1b1148176..55efa486234 100644 --- a/tensorflow/lite/micro/kernels/conv.cc +++ b/tensorflow/lite/micro/kernels/conv.cc @@ -299,6 +299,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(node->user_data != nullptr); const OpData& data = *(static_cast(node->user_data)); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + TF_LITE_ENSURE_MSG(context, input->type == filter->type, + "Hybrid models are not supported on TFLite Micro."); + switch (input->type) { // Already know in/out types are same. case kTfLiteFloat32: EvalFloat(context, node, params, data, input, filter, bias, nullptr, diff --git a/tensorflow/lite/micro/kernels/conv_test.cc b/tensorflow/lite/micro/kernels/conv_test.cc index d0d942c53c8..39f4bac9732 100644 --- a/tensorflow/lite/micro/kernels/conv_test.cc +++ b/tensorflow/lite/micro/kernels/conv_test.cc @@ -54,11 +54,8 @@ static TfLiteConvParams common_conv_params = { }; template -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const T* expected_output_data, T* output_data, - int output_length, - TfLiteConvParams* conv_params, - float tolerance = 1e-5) { +TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, T* output_data, + int output_length, TfLiteConvParams* conv_params) { int inputs_array_data[] = {3, 0, 1, 2}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 3}; @@ -70,14 +67,24 @@ TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, reinterpret_cast(conv_params), micro_test::reporter); const char* init_data = reinterpret_cast(conv_params); - - // TODO(b/154240825): Use a test macro here which fails and returns. TfLiteStatus status = runner.InitAndPrepare(init_data); if (status != kTfLiteOk) { return status; } - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); + return runner.Invoke(); +} +template +TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, + const T* expected_output_data, T* output_data, + int output_length, + TfLiteConvParams* conv_params, + float tolerance = 1e-5) { + TfLiteStatus status = InvokeConv(tensors, tensors_size, output_data, + output_length, conv_params); + if (status != kTfLiteOk) { + return status; + } for (int i = 0; i < output_length; ++i) { TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], tolerance); @@ -100,10 +107,10 @@ void TestConvFloat(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(filter_data, filter_dims), - CreateFloatTensor(bias_data, bias_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(filter_data, filter_dims), + CreateTensor(bias_data, bias_dims), + CreateTensor(output_data, output_dims), }; TF_LITE_MICRO_EXPECT_EQ( @@ -126,8 +133,8 @@ void TestConvQuantizedPerLayer( TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); const int output_dims_count = ElementCount(*output_dims); - tflite::AsymmetricQuantize(expected_output_data, expected_output_quantized, - output_dims_count, output_scale, 128); + tflite::Quantize(expected_output_data, expected_output_quantized, + output_dims_count, output_scale, 128); constexpr int inputs_size = 3; constexpr int outputs_size = 1; @@ -211,9 +218,8 @@ void TestConvQuantizedPerChannel( output_tensor, }; - tflite::AsymmetricQuantize(expected_output_data, - expected_output_data_quantized, output_dims_count, - output_scale, output_zero_point); + tflite::Quantize(expected_output_data, expected_output_data_quantized, + output_dims_count, output_scale, output_zero_point); TF_LITE_MICRO_EXPECT_EQ( kTfLiteOk, ValidateConvGoldens(tensors, tensors_size, expected_output_data_quantized, @@ -278,6 +284,64 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized) { &tflite::testing::common_conv_params); } +TF_LITE_MICRO_TEST(InputOutputDifferentTypeIsError) { + using tflite::testing::CreateQuantizedTensor; + using tflite::testing::CreateTensor; + using tflite::testing::IntArrayFromInts; + + TfLiteIntArray* input_dims = IntArrayFromInts(tflite::testing::kInputShape); + TfLiteIntArray* filter_dims = IntArrayFromInts(tflite::testing::kFilterShape); + TfLiteIntArray* bias_dims = IntArrayFromInts(tflite::testing::kBiasShape); + TfLiteIntArray* output_dims = IntArrayFromInts(tflite::testing::kOutputShape); + const int output_dims_count = tflite::ElementCount(*output_dims); + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + + int8_t output_data[tflite::testing::kOutputElements]; + TfLiteTensor tensors[tensors_size] = { + CreateTensor(tflite::testing::kInputData, input_dims), + CreateTensor(tflite::testing::kFilterData, filter_dims), + CreateTensor(tflite::testing::kBiasData, bias_dims), + CreateQuantizedTensor(output_data, output_dims, /*scale=*/0.0f, + /*zero_point=*/0), + }; + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteError, tflite::testing::InvokeConv( + tensors, tensors_size, output_data, output_dims_count, + &tflite::testing::common_conv_params)); +} + +TF_LITE_MICRO_TEST(HybridModeIsError) { + using tflite::testing::CreateQuantizedTensor; + using tflite::testing::CreateTensor; + using tflite::testing::IntArrayFromInts; + + TfLiteIntArray* input_dims = IntArrayFromInts(tflite::testing::kInputShape); + TfLiteIntArray* filter_dims = IntArrayFromInts(tflite::testing::kFilterShape); + TfLiteIntArray* bias_dims = IntArrayFromInts(tflite::testing::kBiasShape); + TfLiteIntArray* output_dims = IntArrayFromInts(tflite::testing::kOutputShape); + const int output_dims_count = tflite::ElementCount(*output_dims); + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + + int8_t filter_data[tflite::testing::kFilterElements] = {}; + float output_data[tflite::testing::kOutputElements]; + TfLiteTensor tensors[tensors_size] = { + CreateTensor(tflite::testing::kInputData, input_dims), + CreateQuantizedTensor(filter_data, filter_dims, + /*scale=*/0.0f, + /*zero_point=*/0), + CreateTensor(tflite::testing::kBiasData, bias_dims), + CreateTensor(output_data, output_dims), + }; + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteError, tflite::testing::InvokeConv( + tensors, tensors_size, output_data, output_dims_count, + &tflite::testing::common_conv_params)); +} + TF_LITE_MICRO_TEST(SimpleTestDilatedQuantized) { const int output_dims_count = 24; uint8_t output_data[output_dims_count]; @@ -567,8 +631,8 @@ TF_LITE_MICRO_TEST(FilterDimsNotMatchingAffineQuantization) { output_tensor, }; - tflite::AsymmetricQuantize(tflite::testing::kGoldenData, golden_quantized, - output_dims_count, output_scale, 0); + tflite::Quantize(tflite::testing::kGoldenData, golden_quantized, + output_dims_count, output_scale, 0); // Set filter quant to mismatched dimension. TfLiteAffineQuantization* quant = reinterpret_cast( @@ -641,7 +705,7 @@ TF_LITE_MICRO_TEST(BroadcastPerLayerQuantizationToPerChannelShouldMatchGolden) { tflite::testing::kBiasElements, input_scale * output_scale); TfLiteTensor bias_tensor = - tflite::testing::CreateInt32Tensor(bias_quantized, bias_dims); + tflite::testing::CreateTensor(bias_quantized, bias_dims); int bias_zero_points[2] = {1, 0}; float bias_scales[2] = {1, input_scale * filter_scale}; @@ -670,8 +734,8 @@ TF_LITE_MICRO_TEST(BroadcastPerLayerQuantizationToPerChannelShouldMatchGolden) { output_tensor, }; - tflite::AsymmetricQuantize(tflite::testing::kGoldenData, golden_quantized, - output_dims_count, output_scale, 0); + tflite::Quantize(tflite::testing::kGoldenData, golden_quantized, + output_dims_count, output_scale, 0); TF_LITE_MICRO_EXPECT_EQ( kTfLiteOk, tflite::testing::ValidateConvGoldens( @@ -767,7 +831,7 @@ TF_LITE_MICRO_TEST(Int8Input32x1Filter32x32ShouldMatchGolden) { tflite::SymmetricQuantize(bias_values, bias_quantized, kSampleSize, input_scale * output_scale); TfLiteTensor bias_tensor = - tflite::testing::CreateInt32Tensor(bias_quantized, bias_dims); + tflite::testing::CreateTensor(bias_quantized, bias_dims); // There is a single zero point of 0, and a single scale of // input_scale * filter_scale. @@ -802,9 +866,8 @@ TF_LITE_MICRO_TEST(Int8Input32x1Filter32x32ShouldMatchGolden) { }; int8_t golden_quantized[kSampleSize]; - tflite::AsymmetricQuantize(expected_output, golden_quantized, - output_dims_count, output_scale, - output_zero_point); + tflite::Quantize(expected_output, golden_quantized, output_dims_count, + output_scale, output_zero_point); // Rounding errors due to quantization should not exceed 1. constexpr int kQuantizationTolerance = 1; diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc index 358d508a564..d324c9d033b 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc @@ -96,10 +96,10 @@ void TestDepthwiseConvFloat(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(filter_data, filter_dims), - CreateFloatTensor(bias_data, bias_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(filter_data, filter_dims), + CreateTensor(bias_data, bias_dims), + CreateTensor(output_data, output_dims), }; ValidateDepthwiseConvGoldens(expected_output_data, output_dims_count, @@ -151,8 +151,8 @@ void TestDepthwiseConvQuantizedPerLayer( IntArrayFromInts(bias_zero_points), 0}; tensors[2].quantization = {kTfLiteAffineQuantization, &bias_quant}; - AsymmetricQuantize(golden, golden_quantized, output_dims_count, output_scale, - output_zero_point); + Quantize(golden, golden_quantized, output_dims_count, output_scale, + output_zero_point); ValidateDepthwiseConvGoldens(golden_quantized, output_dims_count, conv_params, 1.0, tensors_size, tensors); } @@ -217,8 +217,8 @@ void TestDepthwiseConvQuantizedPerChannel( output_tensor, }; - AsymmetricQuantize(expected_output_data, expected_output_data_quantized, - output_dims_count, output_scale, output_zero_point); + Quantize(expected_output_data, expected_output_data_quantized, + output_dims_count, output_scale, output_zero_point); TF_LITE_MICRO_EXPECT_EQ( kTfLiteOk, ValidateDepthwiseConvGoldens(expected_output_data_quantized, @@ -810,7 +810,7 @@ TF_LITE_MICRO_TEST(PerChannelBroadcastQuantizationParams) { tflite::SymmetricQuantize(bias_values, bias_quantized, bias_elements, input_scale * output_scale); TfLiteTensor bias_tensor = - tflite::testing::CreateInt32Tensor(bias_quantized, bias_dims); + tflite::testing::CreateTensor(bias_quantized, bias_dims); int bias_zero_points[2] = {1, 0}; float bias_scales[2] = {1, input_scale * filter_scale}; @@ -839,8 +839,8 @@ TF_LITE_MICRO_TEST(PerChannelBroadcastQuantizationParams) { output_tensor, }; - tflite::AsymmetricQuantize(golden, golden_quantized, output_dims_count, - output_scale, 0); + tflite::Quantize(golden, golden_quantized, output_dims_count, output_scale, + 0); TfLiteDepthwiseConvParams conv_params; conv_params.activation = kTfLiteActNone; @@ -954,7 +954,7 @@ TF_LITE_MICRO_TEST(Int8Input32x4Filter32x4ShouldMatchGolden) { tflite::SymmetricQuantize(bias_values, bias_quantized, bias_elements, input_scale * output_scale); TfLiteTensor bias_tensor = - tflite::testing::CreateInt32Tensor(bias_quantized, bias_dims); + tflite::testing::CreateTensor(bias_quantized, bias_dims); // Set zero point and scale arrays with a single element for each. int bias_zero_points[] = {1, 0}; @@ -989,8 +989,7 @@ TF_LITE_MICRO_TEST(Int8Input32x4Filter32x4ShouldMatchGolden) { }; int8_t golden_quantized[output_elements]; - tflite::AsymmetricQuantize(golden, golden_quantized, output_elements, - output_scale, 0); + tflite::Quantize(golden, golden_quantized, output_elements, output_scale, 0); // Errors due to quantization should not exceed 1. constexpr int kQuantizationTolerance = 1; diff --git a/tensorflow/lite/micro/kernels/dequantize_test.cc b/tensorflow/lite/micro/kernels/dequantize_test.cc index 86059c63647..8664595a99c 100644 --- a/tensorflow/lite/micro/kernels/dequantize_test.cc +++ b/tensorflow/lite/micro/kernels/dequantize_test.cc @@ -61,7 +61,7 @@ void TestDequantizeToFloat(const int* input_dims_data, const float* input_data, TfLiteTensor tensors[tensors_size] = { CreateQuantizedTensor(input_data, input_data_quantized, input_dims, scale, zero_point), - CreateFloatTensor(output_data, output_dims), + CreateTensor(output_data, output_dims), }; ValidateDequantizeGoldens(tensors, tensors_size, expected_output_data, @@ -84,7 +84,7 @@ void TestDequantizeToInt32(const int* input_dims_data, const float* input_data, TfLiteTensor tensors[tensors_size] = { CreateQuantizedTensor(input_data, input_data_quantized, input_dims, input_scale, input_zero_point), - CreateInt32Tensor(output_data, output_dims), + CreateTensor(output_data, output_dims), }; tensors[1].params.scale = output_scale; diff --git a/tensorflow/lite/micro/kernels/elementwise_test.cc b/tensorflow/lite/micro/kernels/elementwise_test.cc index 1f3b49b3616..665f8d4e0d6 100644 --- a/tensorflow/lite/micro/kernels/elementwise_test.cc +++ b/tensorflow/lite/micro/kernels/elementwise_test.cc @@ -35,9 +35,8 @@ void TestElementwiseFloat(const TfLiteRegistration& registration, constexpr int input_size = 1; constexpr int output_size = 1; constexpr int tensors_size = input_size + output_size; - TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims)}; + TfLiteTensor tensors[tensors_size] = {CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims)}; // Place a unique value in the uninitialized output buffer. for (int i = 0; i < output_dims_count; ++i) { @@ -72,9 +71,8 @@ void TestElementwiseBool(const TfLiteRegistration& registration, constexpr int input_size = 1; constexpr int output_size = 1; constexpr int tensors_size = input_size + output_size; - TfLiteTensor tensors[tensors_size] = { - CreateBoolTensor(input_data, input_dims), - CreateBoolTensor(output_data, output_dims)}; + TfLiteTensor tensors[tensors_size] = {CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims)}; // Place false in the uninitialized output buffer. for (int i = 0; i < output_dims_count; ++i) { diff --git a/tensorflow/lite/micro/kernels/floor_test.cc b/tensorflow/lite/micro/kernels/floor_test.cc index dc9086a07cd..9e9da1ddd57 100644 --- a/tensorflow/lite/micro/kernels/floor_test.cc +++ b/tensorflow/lite/micro/kernels/floor_test.cc @@ -34,8 +34,8 @@ void TestFloor(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {1, 0}; diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc index 3f113010485..ca0a4bcf758 100644 --- a/tensorflow/lite/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc @@ -276,10 +276,10 @@ TfLiteStatus TestFullyConnectedFloat( constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(weights_data, weights_dims), - CreateFloatTensor(bias_data, bias_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(weights_data, weights_dims), + CreateTensor(bias_data, bias_dims), + CreateTensor(output_data, output_dims), }; return ValidateFullyConnectedGoldens(tensors, tensors_size, activation, 1e-4f, @@ -317,8 +317,8 @@ TfLiteStatus TestFullyConnectedQuantized( output_zero_point), }; - AsymmetricQuantize(golden, golden_quantized, output_dims_count, output_scale, - output_zero_point); + Quantize(golden, golden_quantized, output_dims_count, output_scale, + output_zero_point); return ValidateFullyConnectedGoldens(tensors, tensors_size, activation, 0.0f, output_dims_count, golden_quantized, diff --git a/tensorflow/lite/micro/kernels/hard_swish_test.cc b/tensorflow/lite/micro/kernels/hard_swish_test.cc index 91345870023..2b92e902aa3 100644 --- a/tensorflow/lite/micro/kernels/hard_swish_test.cc +++ b/tensorflow/lite/micro/kernels/hard_swish_test.cc @@ -114,8 +114,8 @@ void TestHardSwishQuantized(int size, const T* output_data, TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - AsymmetricDequantize(output_data, output_elements_count, output_scale, - output_zero_point, dequantized_output); + Dequantize(output_data, output_elements_count, output_scale, + output_zero_point, dequantized_output); for (int i = 0; i < output_elements_count; ++i) { TF_LITE_MICRO_EXPECT_NEAR(float_ref_output_values[i], dequantized_output[i], @@ -194,8 +194,8 @@ void TestHardSwishQuantizedBias(const int size, const T* output_data, TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - AsymmetricDequantize(output_data, output_elements_count, output_scale, - output_zero_point, dequantized_output); + Dequantize(output_data, output_elements_count, output_scale, + output_zero_point, dequantized_output); float sum_diff = 0; for (int i = 0; i < size; i++) { @@ -229,8 +229,8 @@ void TestHardSwishFloat(const int size, float* output_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(float_input_values, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(float_input_values, input_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {1, 0}; diff --git a/tensorflow/lite/micro/kernels/l2norm_test.cc b/tensorflow/lite/micro/kernels/l2norm_test.cc index b37c6394a66..cac39278f10 100644 --- a/tensorflow/lite/micro/kernels/l2norm_test.cc +++ b/tensorflow/lite/micro/kernels/l2norm_test.cc @@ -32,7 +32,7 @@ constexpr float kOutputMax = 127.0 / 128.0; TfLiteTensor CreateL2NormTensor(const float* data, TfLiteIntArray* dims, bool is_input) { - return CreateFloatTensor(data, dims); + return CreateTensor(data, dims); } template diff --git a/tensorflow/lite/micro/kernels/logical_test.cc b/tensorflow/lite/micro/kernels/logical_test.cc index 67606e772e4..cca2e6a2eb7 100644 --- a/tensorflow/lite/micro/kernels/logical_test.cc +++ b/tensorflow/lite/micro/kernels/logical_test.cc @@ -38,9 +38,9 @@ void TestLogicalOp(const TfLiteRegistration& registration, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateBoolTensor(input1_data, input1_dims), - CreateBoolTensor(input2_data, input2_dims), - CreateBoolTensor(output_data, output_dims), + CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {2, 0, 1}; diff --git a/tensorflow/lite/micro/kernels/logistic_test.cc b/tensorflow/lite/micro/kernels/logistic_test.cc index 7ba2dd8f52f..3099f2972dc 100644 --- a/tensorflow/lite/micro/kernels/logistic_test.cc +++ b/tensorflow/lite/micro/kernels/logistic_test.cc @@ -79,8 +79,8 @@ void TestLogisticFloat(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; ValidateLogisticGoldens(tensors, tensors_size, output_data, golden, @@ -108,8 +108,8 @@ void TestLogisticQuantized(const int* input_dims_data, const float* input_data, output_zero_point), }; - tflite::AsymmetricQuantize(golden, golden_quantized, output_elements_count, - output_scale, output_zero_point); + tflite::Quantize(golden, golden_quantized, output_elements_count, + output_scale, output_zero_point); ValidateLogisticGoldens(tensors, tensors_size, output_data, golden_quantized, output_elements_count, 1.0); } diff --git a/tensorflow/lite/micro/kernels/maximum_minimum_test.cc b/tensorflow/lite/micro/kernels/maximum_minimum_test.cc index 0db93ff18cb..9c0eac0726e 100644 --- a/tensorflow/lite/micro/kernels/maximum_minimum_test.cc +++ b/tensorflow/lite/micro/kernels/maximum_minimum_test.cc @@ -38,9 +38,9 @@ void TestMaxMinFloat(const TfLiteRegistration& registration, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input1_data, input1_dims), - CreateFloatTensor(input2_data, input2_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {2, 0, 1}; @@ -118,9 +118,9 @@ void TestMaxMinQuantizedInt32( constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateInt32Tensor(input1_data, input1_dims), - CreateInt32Tensor(input2_data, input2_dims), - CreateInt32Tensor(output_data, output_dims), + CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {2, 0, 1}; diff --git a/tensorflow/lite/micro/kernels/mul_test.cc b/tensorflow/lite/micro/kernels/mul_test.cc index 8503a1502d1..5c0fe275e07 100644 --- a/tensorflow/lite/micro/kernels/mul_test.cc +++ b/tensorflow/lite/micro/kernels/mul_test.cc @@ -80,9 +80,9 @@ void TestMulFloat(const int* input1_dims_data, const float* input1_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input1_data, input1_dims), - CreateFloatTensor(input2_data, input2_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims), }; ValidateMulGoldens(tensors, tensors_size, activation, golden, @@ -114,8 +114,8 @@ void TestMulQuantized(const int* input1_dims_data, const float* input1_data, CreateQuantizedTensor(output_data, output_dims, output_scale, output_zero_point)}; - AsymmetricQuantize(golden, golden_quantized, output_dims_count, output_scale, - output_zero_point); + Quantize(golden, golden_quantized, output_dims_count, output_scale, + output_zero_point); ValidateMulGoldens(tensors, tensors_size, activation, golden_quantized, output_dims_count, 1.0f, output_data); diff --git a/tensorflow/lite/micro/kernels/neg_test.cc b/tensorflow/lite/micro/kernels/neg_test.cc index f3c0e7d36a8..40111dca0d4 100644 --- a/tensorflow/lite/micro/kernels/neg_test.cc +++ b/tensorflow/lite/micro/kernels/neg_test.cc @@ -34,8 +34,8 @@ void TestNegFloat(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {1, 0}; diff --git a/tensorflow/lite/micro/kernels/pack_test.cc b/tensorflow/lite/micro/kernels/pack_test.cc index 5ac80d698b5..d523db3e983 100644 --- a/tensorflow/lite/micro/kernels/pack_test.cc +++ b/tensorflow/lite/micro/kernels/pack_test.cc @@ -61,10 +61,9 @@ void TestPackTwoInputsFloat(const int* input1_dims_data, constexpr int input_size = 2; constexpr int output_size = 1; constexpr int tensors_size = input_size + output_size; - TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input1_data, input1_dims), - CreateFloatTensor(input2_data, input2_dims), - CreateFloatTensor(output_data, output_dims)}; + TfLiteTensor tensors[tensors_size] = {CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims)}; TfLitePackParams builtin_data = { .values_count = 2, @@ -95,11 +94,10 @@ void TestPackThreeInputsFloat( constexpr int input_size = 3; constexpr int output_size = 1; constexpr int tensors_size = input_size + output_size; - TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input1_data, input1_dims), - CreateFloatTensor(input2_data, input2_dims), - CreateFloatTensor(input3_data, input3_dims), - CreateFloatTensor(output_data, output_dims)}; + TfLiteTensor tensors[tensors_size] = {CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(input3_data, input3_dims), + CreateTensor(output_data, output_dims)}; TfLitePackParams builtin_data = { .values_count = 3, @@ -167,10 +165,9 @@ void TestPackTwoInputsQuantized32(const int* input1_dims_data, constexpr int input_size = 2; constexpr int output_size = 1; constexpr int tensors_size = input_size + output_size; - TfLiteTensor tensors[tensors_size] = { - CreateInt32Tensor(input1_data, input1_dims), - CreateInt32Tensor(input2_data, input2_dims), - CreateInt32Tensor(output_data, output_dims)}; + TfLiteTensor tensors[tensors_size] = {CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims)}; TfLitePackParams builtin_data = { .values_count = 2, diff --git a/tensorflow/lite/micro/kernels/pad_test.cc b/tensorflow/lite/micro/kernels/pad_test.cc index e94bc993fea..859fc1b05e9 100644 --- a/tensorflow/lite/micro/kernels/pad_test.cc +++ b/tensorflow/lite/micro/kernels/pad_test.cc @@ -101,10 +101,9 @@ void TestPadFloat(const int* input_dims_data, const float* input_data, constexpr int inputs_size = 2; constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; - TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateInt32Tensor(pad_data, pad_dims), - CreateFloatTensor(output_data, output_dims)}; + TfLiteTensor tensors[tensors_size] = {CreateTensor(input_data, input_dims), + CreateTensor(pad_data, pad_dims), + CreateTensor(output_data, output_dims)}; // Pad tensor must be constant. tensors[1].allocation_type = kTfLiteMmapRo; @@ -130,10 +129,9 @@ void TestPadV2Float(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateInt32Tensor(pad_data, pad_dims), - CreateFloatTensor(&pad_value, pad_value_dims), - CreateFloatTensor(output_data, output_dims)}; + CreateTensor(input_data, input_dims), CreateTensor(pad_data, pad_dims), + CreateTensor(&pad_value, pad_value_dims), + CreateTensor(output_data, output_dims)}; // Pad tensor must be constant. tensors[1].allocation_type = kTfLiteMmapRo; @@ -161,15 +159,15 @@ void TestPadQuantized(const int* input_dims_data, const float* input_data, TfLiteTensor tensors[tensors_size] = { CreateQuantizedTensor(input_data, input_quantized, input_dims, input_scale, input_zero_point), - CreateInt32Tensor(pad_data, pad_dims), + CreateTensor(pad_data, pad_dims), CreateQuantizedTensor(output_data, output_dims, output_scale, output_zero_point)}; // Pad tensor must be constant. tensors[1].allocation_type = kTfLiteMmapRo; - tflite::AsymmetricQuantize(golden, golden_quantized, output_dims_count, - output_scale, output_zero_point); + tflite::Quantize(golden, golden_quantized, output_dims_count, output_scale, + output_zero_point); TF_LITE_MICRO_EXPECT_EQ( expected_status, ValidatePadGoldens(tensors, tensors_size, golden_quantized, output_data, @@ -200,7 +198,7 @@ void TestPadV2Quantized(const int* input_dims_data, const float* input_data, TfLiteTensor tensors[tensors_size] = { CreateQuantizedTensor(input_data, input_quantized, input_dims, input_scale, input_zero_point), - CreateInt32Tensor(pad_data, pad_dims), + CreateTensor(pad_data, pad_dims), CreateQuantizedTensor(&pad_value, &pad_value_quantized, pad_value_dims, pad_value_scale, pad_value_zero_point), CreateQuantizedTensor(output_data, output_dims, output_scale, @@ -211,8 +209,8 @@ void TestPadV2Quantized(const int* input_dims_data, const float* input_data, tensors[2].params.scale = pad_value_scale; tensors[3].params.scale = output_scale; - tflite::AsymmetricQuantize(golden, golden_quantized, output_dims_count, - output_scale, output_zero_point); + tflite::Quantize(golden, golden_quantized, output_dims_count, output_scale, + output_zero_point); TF_LITE_MICRO_EXPECT_EQ( expected_status, ValidatePadV2Goldens(tensors, tensors_size, golden_quantized, output_data, diff --git a/tensorflow/lite/micro/kernels/pooling_test.cc b/tensorflow/lite/micro/kernels/pooling_test.cc index 9782b49ad98..2f384597e7c 100644 --- a/tensorflow/lite/micro/kernels/pooling_test.cc +++ b/tensorflow/lite/micro/kernels/pooling_test.cc @@ -73,8 +73,8 @@ void TestAveragePoolFloat(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; const TfLiteRegistration registration = @@ -131,8 +131,8 @@ void TestMaxPoolFloat(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; const TfLiteRegistration registration = diff --git a/tensorflow/lite/micro/kernels/prelu_test.cc b/tensorflow/lite/micro/kernels/prelu_test.cc index 3a0b10a0d94..92acecf052a 100644 --- a/tensorflow/lite/micro/kernels/prelu_test.cc +++ b/tensorflow/lite/micro/kernels/prelu_test.cc @@ -57,9 +57,9 @@ void TestPreluFloat(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(alpha_data, alpha_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(alpha_data, alpha_dims), + CreateTensor(output_data, output_dims), }; ValidatePreluGoldens(tensors, tensors_size, expected_output_data, @@ -93,8 +93,8 @@ void TestPreluQuantized(const int* input_dims_data, const float* input_data, output_zero_point), }; - AsymmetricQuantize(golden, golden_quantized, output_dims_count, output_scale, - output_zero_point); + Quantize(golden, golden_quantized, output_dims_count, output_scale, + output_zero_point); ValidatePreluGoldens(tensors, tensors_size, golden_quantized, output_dims_count, output_data); diff --git a/tensorflow/lite/micro/kernels/quantize_test.cc b/tensorflow/lite/micro/kernels/quantize_test.cc index b630fb53bca..fdcf65f9ce4 100644 --- a/tensorflow/lite/micro/kernels/quantize_test.cc +++ b/tensorflow/lite/micro/kernels/quantize_test.cc @@ -43,7 +43,7 @@ void ValidateQuantizeGoldens(TfLiteTensor* tensors, int tensors_size, TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); // Use reference quantization from test utils to compare against op output. - AsymmetricQuantize(golden, golden_quantized, output_len, scale, zero_point); + Quantize(golden, golden_quantized, output_len, scale, zero_point); for (int i = 0; i < output_len; ++i) { TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]); } @@ -71,7 +71,7 @@ void TestQuantizeFloat(const int* input_dims_data, const float* input_data, // 1 input, 1 output. constexpr int tensors_size = 2; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), + CreateTensor(input_data, input_dims), output_tensor, }; diff --git a/tensorflow/lite/micro/kernels/reduce_test.cc b/tensorflow/lite/micro/kernels/reduce_test.cc index fdb8fe95466..3666bc0b2fb 100644 --- a/tensorflow/lite/micro/kernels/reduce_test.cc +++ b/tensorflow/lite/micro/kernels/reduce_test.cc @@ -106,9 +106,9 @@ void TestMeanFloatInput4D(const int* input_dims_data, const float* input_data, constexpr int tensors_size = num_of_inputs + num_of_outputs; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateInt32Tensor(axis_data, axis_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(axis_data, axis_dims), + CreateTensor(output_data, output_dims), }; TF_LITE_MICRO_EXPECT_EQ( @@ -133,9 +133,9 @@ void TestReduceOpFloat(const int* input_dims_data, const float* input_data, constexpr int tensors_size = num_of_inputs + num_of_outputs; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateInt32Tensor(axis_data, axis_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(axis_data, axis_dims), + CreateTensor(output_data, output_dims), }; TF_LITE_MICRO_EXPECT_EQ( @@ -165,15 +165,14 @@ void TestReduceOpQuantized( TfLiteTensor tensors[] = { CreateQuantizedTensor(input_data, input_data_quant, input_dims, input_scale, input_zero_point), - CreateInt32Tensor(axis_data, axis_dims), + CreateTensor(axis_data, axis_dims), CreateQuantizedTensor(output_data_quant, output_dims, output_scale, output_zero_point), }; // Quantize expected output - tflite::AsymmetricQuantize(expected_output_data, expected_output_data_quant, - output_dims_count, output_scale, - output_zero_point); + tflite::Quantize(expected_output_data, expected_output_data_quant, + output_dims_count, output_scale, output_zero_point); TF_LITE_MICRO_EXPECT_EQ( kTfLiteOk, @@ -204,15 +203,14 @@ void TestMeanOpQuantized(const int* input_dims_data, const float* input_data, TfLiteTensor tensors[] = { CreateQuantizedTensor(input_data, input_data_quant, input_dims, input_scale, input_zero_point), - CreateInt32Tensor(axis_data, axis_dims), + CreateTensor(axis_data, axis_dims), CreateQuantizedTensor(output_data_quant, output_dims, output_scale, output_zero_point), }; // Quantize expected output - tflite::AsymmetricQuantize(expected_output_data, expected_output_data_quant, - output_dims_count, output_scale, - output_zero_point); + tflite::Quantize(expected_output_data, expected_output_data_quant, + output_dims_count, output_scale, output_zero_point); TF_LITE_MICRO_EXPECT_EQ( kTfLiteOk, diff --git a/tensorflow/lite/micro/kernels/reshape_test.cc b/tensorflow/lite/micro/kernels/reshape_test.cc index 48d1956f1c8..9e1da3ca51d 100644 --- a/tensorflow/lite/micro/kernels/reshape_test.cc +++ b/tensorflow/lite/micro/kernels/reshape_test.cc @@ -121,9 +121,9 @@ void TestReshape(const int* input_dims_data, const float* input_data, TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); - TfLiteTensor input_tensor = CreateFloatTensor(input_data, input_dims); - TfLiteTensor shape_tensor = CreateInt32Tensor(shape_data, shape_dims); - TfLiteTensor output_tensor = CreateFloatTensor(output_data, output_dims); + TfLiteTensor input_tensor = CreateTensor(input_data, input_dims); + TfLiteTensor shape_tensor = CreateTensor(shape_data, shape_dims); + TfLiteTensor output_tensor = CreateTensor(output_data, output_dims); TestReshapeWithShape(&input_tensor, &shape_tensor, &output_tensor, expected_output, expected_output_len, expected_dims, @@ -144,7 +144,7 @@ void TestReshapeQuantized(const int* input_dims_data, const T* input_data, TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); TfLiteTensor input_tensor = CreateQuantizedTensor( input_data, input_dims, /*scale=*/1.f, /*zero_point=*/0); - TfLiteTensor shape_tensor = CreateInt32Tensor(shape_data, shape_dims); + TfLiteTensor shape_tensor = CreateTensor(shape_data, shape_dims); TfLiteTensor output_tensor = CreateQuantizedTensor( output_data, output_dims, /*scale=*/1.f, /*zero_point=*/0); @@ -213,14 +213,12 @@ TF_LITE_MICRO_TEST(ReshapeWithInvalidShapeShouldFail) { TfLiteIntArray* input_dims = tflite::testing::IntArrayFromInts(input_dims_data); const float input_data[] = {3.0f}; - auto input_tensor = - tflite::testing::CreateFloatTensor(input_data, input_dims); + auto input_tensor = tflite::testing::CreateTensor(input_data, input_dims); float output_data[4]; int output_dims_data[6] = {2, 2, 1, 2, 2, 1}; TfLiteIntArray* output_dims = tflite::testing::IntArrayFromInts(output_dims_data); - auto output_tensor = - tflite::testing::CreateFloatTensor(output_data, output_dims); + auto output_tensor = tflite::testing::CreateTensor(output_data, output_dims); const int expected_output[] = {}; const int expected_output_len = 0; const int expected_dims[] = {}; @@ -328,25 +326,24 @@ TF_LITE_MICRO_TEST(ReshapeWithScalarOutputShouldSucceed) { // Some old models specify '[0]' as the new shape, indicating that both input // and output are scalars. TF_LITE_MICRO_TEST(ReshapeWithLegacyScalarOutputShouldSucceed) { - using tflite::testing::CreateFloatTensor; + using tflite::testing::CreateTensor; using tflite::testing::IntArrayFromInts; int input_dims_data[] = {1, 1}; TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); const float input_data[] = {3.0f}; - auto input_tensor = CreateFloatTensor(input_data, input_dims); + auto input_tensor = CreateTensor(input_data, input_dims); float output_data[1]; int output_dims_data[2] = {1, 0}; TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); - auto output_tensor = CreateFloatTensor(output_data, output_dims); + auto output_tensor = CreateTensor(output_data, output_dims); int shape_dims_data[] = {1, 0}; TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data); const int32_t shape_data[] = {0}; - auto shape_tensor = - tflite::testing::CreateInt32Tensor(shape_data, shape_dims); + auto shape_tensor = tflite::testing::CreateTensor(shape_data, shape_dims); const float expected_output_with_shape[] = {}; const int expected_output_with_shape_len = 0; const float expected_output_no_shape[] = {3}; diff --git a/tensorflow/lite/micro/kernels/resize_nearest_neighbor_test.cc b/tensorflow/lite/micro/kernels/resize_nearest_neighbor_test.cc index 9362a89a3ed..f1af763d9bb 100644 --- a/tensorflow/lite/micro/kernels/resize_nearest_neighbor_test.cc +++ b/tensorflow/lite/micro/kernels/resize_nearest_neighbor_test.cc @@ -27,7 +27,7 @@ using uint8_t = std::uint8_t; using int32_t = std::int32_t; TfLiteTensor TestCreateTensor(const float* data, TfLiteIntArray* dims) { - return CreateFloatTensor(data, dims); + return CreateTensor(data, dims); } TfLiteTensor TestCreateTensor(const uint8_t* data, TfLiteIntArray* dims) { @@ -59,7 +59,7 @@ void TestResizeNearestNeighbor(const int* input_dims_data, const T* input_data, constexpr int tensors_size = 3; TfLiteTensor tensors[tensors_size] = { TestCreateTensor(input_data, input_dims), - CreateInt32Tensor(expected_size_data, expected_size_dims), + CreateTensor(expected_size_data, expected_size_dims), TestCreateTensor(output_data, output_dims), }; diff --git a/tensorflow/lite/micro/kernels/round_test.cc b/tensorflow/lite/micro/kernels/round_test.cc index 8067d8cd091..412ecf5b539 100644 --- a/tensorflow/lite/micro/kernels/round_test.cc +++ b/tensorflow/lite/micro/kernels/round_test.cc @@ -33,8 +33,8 @@ void TestRound(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {1, 0}; diff --git a/tensorflow/lite/micro/kernels/shape_test.cc b/tensorflow/lite/micro/kernels/shape_test.cc index 7c7e0db82db..5bfdee5bb10 100755 --- a/tensorflow/lite/micro/kernels/shape_test.cc +++ b/tensorflow/lite/micro/kernels/shape_test.cc @@ -55,8 +55,8 @@ void TestShape(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateInt32Tensor(output_data, output_dims, true), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims, true), }; ValidateShape(tensors, tensors_size, output_data, expected_output_data, diff --git a/tensorflow/lite/micro/kernels/softmax_test.cc b/tensorflow/lite/micro/kernels/softmax_test.cc index 21fc1074760..bfc1c4b61ff 100644 --- a/tensorflow/lite/micro/kernels/softmax_test.cc +++ b/tensorflow/lite/micro/kernels/softmax_test.cc @@ -281,8 +281,8 @@ void TestSoftmaxFloat(const int* input_dims_data, const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; ValidateSoftmaxGoldens(tensors, tensors_size, output_data, @@ -310,8 +310,8 @@ void TestSoftmaxQuantized(const int* input_dims_data, const float* input_data, output_zero_point), }; - AsymmetricQuantize(golden, golden_quantized, output_dims_count, output_scale, - output_zero_point); + Quantize(golden, golden_quantized, output_dims_count, output_scale, + output_zero_point); ValidateSoftmaxGoldens(tensors, tensors_size, output_data, golden_quantized, output_dims_count, tolerance); diff --git a/tensorflow/lite/micro/kernels/split_test.cc b/tensorflow/lite/micro/kernels/split_test.cc index cd9a90804e0..b5d038cdc3a 100644 --- a/tensorflow/lite/micro/kernels/split_test.cc +++ b/tensorflow/lite/micro/kernels/split_test.cc @@ -42,10 +42,9 @@ void TestSplitTwoOutputsFloat( constexpr int axis_size = 1; constexpr int tensors_size = input_size + output_size + axis_size; TfLiteTensor tensors[tensors_size] = { - CreateInt32Tensor(axis_data, axis_dims), - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output1_data, output1_dims), - CreateFloatTensor(output2_data, output2_dims)}; + CreateTensor(axis_data, axis_dims), CreateTensor(input_data, input_dims), + CreateTensor(output1_data, output1_dims), + CreateTensor(output2_data, output2_dims)}; // Currently only support constant axis tensor. tensors[0].allocation_type = kTfLiteMmapRo; @@ -104,12 +103,12 @@ void TestSplitFourOutputsFloat( constexpr int axis_size = 1; constexpr int tensors_size = input_size + output_size + axis_size; TfLiteTensor tensors[tensors_size] = { - CreateInt32Tensor(axis_data, axis_dims), - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output1_data, output1_dims), - CreateFloatTensor(output2_data, output2_dims), - CreateFloatTensor(output3_data, output1_dims), - CreateFloatTensor(output4_data, output1_dims)}; + CreateTensor(axis_data, axis_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output1_data, output1_dims), + CreateTensor(output2_data, output2_dims), + CreateTensor(output3_data, output1_dims), + CreateTensor(output4_data, output1_dims)}; // Currently only support constant axis tensor. tensors[0].allocation_type = kTfLiteMmapRo; @@ -171,7 +170,7 @@ void TestSplitTwoOutputsQuantized( constexpr int axis_size = 1; constexpr int tensors_size = input_size + output_size + axis_size; TfLiteTensor tensors[tensors_size] = { - CreateInt32Tensor(axis_data, axis_dims), + CreateTensor(axis_data, axis_dims), CreateQuantizedTensor(input_data, input_dims, 0, 10), CreateQuantizedTensor(output1_data, output1_dims, 0, 10), CreateQuantizedTensor(output2_data, output2_dims, 0, 10)}; @@ -227,10 +226,9 @@ void TestSplitTwoOutputsQuantized32( constexpr int axis_size = 1; constexpr int tensors_size = input_size + output_size + axis_size; TfLiteTensor tensors[tensors_size] = { - CreateInt32Tensor(axis_data, axis_dims), - CreateInt32Tensor(input_data, input_dims), - CreateInt32Tensor(output1_data, output1_dims), - CreateInt32Tensor(output2_data, output2_dims)}; + CreateTensor(axis_data, axis_dims), CreateTensor(input_data, input_dims), + CreateTensor(output1_data, output1_dims), + CreateTensor(output2_data, output2_dims)}; // Currently only support constant axis tensor. tensors[0].allocation_type = kTfLiteMmapRo; diff --git a/tensorflow/lite/micro/kernels/split_v_test.cc b/tensorflow/lite/micro/kernels/split_v_test.cc index 6a41b2b1985..06c90cb69e3 100755 --- a/tensorflow/lite/micro/kernels/split_v_test.cc +++ b/tensorflow/lite/micro/kernels/split_v_test.cc @@ -63,13 +63,13 @@ void TestSplitVFloat(const int* input_dims_data, const float* input_data, // then come outputs TfLiteTensor tensors[tensors_size]; - tensors[0] = CreateFloatTensor(input_data, input_dims); - tensors[1] = CreateInt32Tensor(split_data, split_dims); - tensors[2] = CreateInt32Tensor(axis_data, axis_dims); + tensors[0] = CreateTensor(input_data, input_dims); + tensors[1] = CreateTensor(split_data, split_dims); + tensors[2] = CreateTensor(axis_data, axis_dims); // add output tensors for (int i = 0; i < N; i++) - tensors[3 + i] = CreateFloatTensor(output_tensors.data[i], output_dims[i]); + tensors[3 + i] = CreateTensor(output_tensors.data[i], output_dims[i]); tensors[2].allocation_type = kTfLiteMmapRo; tensors[1].allocation_type = kTfLiteMmapRo; diff --git a/tensorflow/lite/micro/kernels/strided_slice_test.cc b/tensorflow/lite/micro/kernels/strided_slice_test.cc index a6de5bd1e59..7f8446001eb 100644 --- a/tensorflow/lite/micro/kernels/strided_slice_test.cc +++ b/tensorflow/lite/micro/kernels/strided_slice_test.cc @@ -74,11 +74,11 @@ void TestStridedSliceFloat(const int* input_shape, const int* begin_shape, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateInt32Tensor(begin_data, begin_dims), - CreateInt32Tensor(end_data, end_dims), - CreateInt32Tensor(strides_data, strides_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(begin_data, begin_dims), + CreateTensor(end_data, end_dims), + CreateTensor(strides_data, strides_dims), + CreateTensor(output_data, output_dims), }; ValidateStridedSliceGoldens(tensors, tensors_size, expected_output, @@ -106,9 +106,9 @@ void TestStridedSliceQuantized( std::numeric_limits::max() + std::numeric_limits::min() / 2; TfLiteTensor tensors[tensors_size] = { CreateQuantizedTensor(input_data, input_dims, 1.0, zero_point), - CreateInt32Tensor(begin_data, begin_dims), - CreateInt32Tensor(end_data, end_dims), - CreateInt32Tensor(strides_data, strides_dims), + CreateTensor(begin_data, begin_dims), + CreateTensor(end_data, end_dims), + CreateTensor(strides_data, strides_dims), CreateQuantizedTensor(output_data, output_dims, 1.0, zero_point), }; diff --git a/tensorflow/lite/micro/kernels/sub_test.cc b/tensorflow/lite/micro/kernels/sub_test.cc index 1cc0c80527b..badca6e14e4 100644 --- a/tensorflow/lite/micro/kernels/sub_test.cc +++ b/tensorflow/lite/micro/kernels/sub_test.cc @@ -99,9 +99,9 @@ void TestSubFloat(const int* input1_dims_data, const float* input1_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input1_data, input1_dims), - CreateFloatTensor(input2_data, input2_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input1_data, input1_dims), + CreateTensor(input2_data, input2_dims), + CreateTensor(output_data, output_dims), }; ValidateSubGoldens(tensors, tensors_size, expected_output, output_data, @@ -135,9 +135,8 @@ void TestSubQuantized(const int* input1_dims_data, const float* input1_data, tflite::testing::CreateQuantizedTensor(output_data, output_dims, output_scale, output_zero_point), }; - tflite::AsymmetricQuantize(golden, golden_quantized, - ElementCount(*output_dims), output_scale, - output_zero_point); + tflite::Quantize(golden, golden_quantized, ElementCount(*output_dims), + output_scale, output_zero_point); ValidateSubGoldens(tensors, tensors_size, golden_quantized, output_data, ElementCount(*output_dims), activation); diff --git a/tensorflow/lite/micro/kernels/svdf_test.cc b/tensorflow/lite/micro/kernels/svdf_test.cc index 771ff66a4b7..775477b9710 100644 --- a/tensorflow/lite/micro/kernels/svdf_test.cc +++ b/tensorflow/lite/micro/kernels/svdf_test.cc @@ -565,13 +565,13 @@ void TestSVDF(const int batch_size, const int num_units, const int input_size, const int tensor_count = 6; // 5 inputs, 1 output TfLiteTensor tensors[] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(feature_weights_data, feature_weights_dims), - CreateFloatTensor(time_weights_data, time_weights_dims), - CreateFloatTensor(bias_data, bias_dims), - CreateFloatTensor(activation_state_data, activation_state_dims, - /*is_variable=*/true), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(feature_weights_data, feature_weights_dims), + CreateTensor(time_weights_data, time_weights_dims), + CreateTensor(bias_data, bias_dims), + CreateTensor(activation_state_data, activation_state_dims, + /*is_variable=*/true), + CreateTensor(output_data, output_dims), }; ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors, @@ -640,12 +640,10 @@ inline void TestIntegerSVDF( CreateQuantizedTensor(output_data, output_dims, output_scale, output_zero_point)}; - tflite::AsymmetricQuantize(golden_output, golden_output_quantized, - golden_output_len, output_scale, - output_zero_point); - tflite::AsymmetricQuantize(input_sequences_data, input_sequences_quantized, - input_sequences_len, input_scale, - input_zero_point); + tflite::Quantize(golden_output, golden_output_quantized, golden_output_len, + output_scale, output_zero_point); + tflite::Quantize(input_sequences_data, input_sequences_quantized, + input_sequences_len, input_scale, input_zero_point); ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors, tensor_count, activation, input_sequences_quantized, diff --git a/tensorflow/lite/micro/kernels/tanh_test.cc b/tensorflow/lite/micro/kernels/tanh_test.cc index 4a4f94bc2e5..52a03aedcff 100644 --- a/tensorflow/lite/micro/kernels/tanh_test.cc +++ b/tensorflow/lite/micro/kernels/tanh_test.cc @@ -77,8 +77,8 @@ void TestTanhFloat(const int input_dims_data[], const float* input_data, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims), + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), }; int inputs_array_data[] = {1, 0}; @@ -113,9 +113,8 @@ void TestTanhQuantized(const int input_dims_data[], const float* input_data, TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); const int output_elements_count = ElementCount(*output_dims); - tflite::AsymmetricQuantize(expected_output_data, expected_output_quantized, - output_elements_count, output_scale, - output_zero_point); + tflite::Quantize(expected_output_data, expected_output_quantized, + output_elements_count, output_scale, output_zero_point); constexpr int inputs_size = 1; constexpr int outputs_size = 1; diff --git a/tensorflow/lite/micro/kernels/unpack_test.cc b/tensorflow/lite/micro/kernels/unpack_test.cc index b5c17bd8d2f..95846651cd0 100644 --- a/tensorflow/lite/micro/kernels/unpack_test.cc +++ b/tensorflow/lite/micro/kernels/unpack_test.cc @@ -41,10 +41,10 @@ void TestUnpackThreeOutputsFloat( constexpr int output_size = 3; constexpr int tensors_size = input_size + output_size; TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output1_data, output1_dims), - CreateFloatTensor(output2_data, output2_dims), - CreateFloatTensor(output3_data, output3_dims)}; + CreateTensor(input_data, input_dims), + CreateTensor(output1_data, output1_dims), + CreateTensor(output2_data, output2_dims), + CreateTensor(output3_data, output3_dims)}; // Place a unique value in the uninitialized output buffer. for (int i = 0; i < output1_dims_count; ++i) { @@ -102,9 +102,8 @@ void TestUnpackOneOutputFloat(const int* input_dims_data, constexpr int input_size = 1; constexpr int output_size = 1; constexpr int tensors_size = input_size + output_size; - TfLiteTensor tensors[tensors_size] = { - CreateFloatTensor(input_data, input_dims), - CreateFloatTensor(output_data, output_dims)}; + TfLiteTensor tensors[tensors_size] = {CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims)}; // Place a unique value in the uninitialized output buffer. for (int i = 0; i < output_dims_count; ++i) { @@ -222,10 +221,10 @@ void TestUnpackThreeOutputsQuantized32( constexpr int output_size = 3; constexpr int tensors_size = input_size + output_size; TfLiteTensor tensors[tensors_size] = { - CreateInt32Tensor(input_data, input_dims), - CreateInt32Tensor(output1_data, output1_dims), - CreateInt32Tensor(output2_data, output2_dims), - CreateInt32Tensor(output3_data, output3_dims)}; + CreateTensor(input_data, input_dims), + CreateTensor(output1_data, output1_dims), + CreateTensor(output2_data, output2_dims), + CreateTensor(output3_data, output3_dims)}; // Place a unique value in the uninitialized output buffer. for (int i = 0; i < output1_dims_count; ++i) { diff --git a/tensorflow/lite/micro/memory_helpers_test.cc b/tensorflow/lite/micro/memory_helpers_test.cc index 5000a880638..566ad369849 100644 --- a/tensorflow/lite/micro/memory_helpers_test.cc +++ b/tensorflow/lite/micro/memory_helpers_test.cc @@ -180,11 +180,11 @@ TF_LITE_MICRO_TEST(TestAllocateOutputDimensionsFromInput) { const int input1_dims[] = {1, 1}; const int input2_dims[] = {kDimsLen, 5, 5, 5, 5}; int output_dims[] = {0, 0, 0, 0, 0}; - TfLiteTensor input_tensor1 = tflite::testing::CreateInt32Tensor( + TfLiteTensor input_tensor1 = tflite::testing::CreateTensor( nullptr, tflite::testing::IntArrayFromInts(input1_dims)); - TfLiteTensor input_tensor2 = tflite::testing::CreateInt32Tensor( + TfLiteTensor input_tensor2 = tflite::testing::CreateTensor( nullptr, tflite::testing::IntArrayFromInts(input2_dims)); - TfLiteTensor output_tensor = tflite::testing::CreateInt32Tensor( + TfLiteTensor output_tensor = tflite::testing::CreateTensor( nullptr, tflite::testing::IntArrayFromInts(output_dims)); TfLiteContext context; // Only need to allocate space for output_tensor.dims. Use a simple diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc index 770921b4234..71a8493289f 100644 --- a/tensorflow/lite/micro/micro_allocator.cc +++ b/tensorflow/lite/micro/micro_allocator.cc @@ -621,13 +621,6 @@ MicroAllocator::~MicroAllocator() {} MicroAllocator* MicroAllocator::Create(uint8_t* tensor_arena, size_t arena_size, ErrorReporter* error_reporter) { uint8_t* aligned_arena = AlignPointerUp(tensor_arena, kBufferAlignment); - if (aligned_arena != tensor_arena) { - TF_LITE_REPORT_ERROR( - error_reporter, - "%d bytes lost due to alignment. To avoid this loss, please make sure " - "the tensor_arena is 16 bytes aligned.", - aligned_arena - tensor_arena); - } size_t aligned_arena_size = tensor_arena + arena_size - aligned_arena; return Create(SimpleMemoryAllocator::Create(error_reporter, aligned_arena, aligned_arena_size), @@ -818,6 +811,8 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer( GetRegistrationFromOpCode(opcode, op_resolver, error_reporter_, &(node_and_registrations[i].registration)); if (status != kTfLiteOk) { + // TODO(b/171278094): Use the GetBuiltinCode method in the schema utilitly + // to get builtin code from op code. TF_LITE_REPORT_ERROR(error_reporter_, "Failed to get registration from op code %s\n ", EnumNameBuiltinOperator(opcode->builtin_code())); diff --git a/tensorflow/lite/micro/micro_utils.cc b/tensorflow/lite/micro/micro_utils.cc index ff885fa04ff..96152364c25 100644 --- a/tensorflow/lite/micro/micro_utils.cc +++ b/tensorflow/lite/micro/micro_utils.cc @@ -15,34 +15,15 @@ limitations under the License. #include "tensorflow/lite/micro/micro_utils.h" -#include -#include -#include +#include +#include +#include #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/op_macros.h" namespace tflite { -namespace { - -static const uint8_t kAsymmetricUInt8Min = 0; -static const uint8_t kAsymmetricUInt8Max = UINT8_MAX; -static const uint8_t kSymmetricUInt8Min = 1; -static const uint8_t kSymmetricUInt8Max = UINT8_MAX; -static const int8_t kAsymmetricInt8Min = INT8_MIN; -static const int8_t kAsymmetricInt8Max = INT8_MAX; -static const int kSymmetricInt8Scale = kAsymmetricInt8Max; - -static const int16_t kAsymmetricInt16Min = INT16_MIN; -static const int16_t kAsymmetricInt16Max = INT16_MAX; -static const int kSymmetricInt16Scale = kAsymmetricInt16Max; - -static const int32_t kAsymmetricInt32Max = INT32_MAX; -static const int kSymmetricInt32Scale = kAsymmetricInt32Max; - -} // namespace - int ElementCount(const TfLiteIntArray& dims) { int result = 1; for (int i = 0; i < dims.size; ++i) { @@ -51,109 +32,6 @@ int ElementCount(const TfLiteIntArray& dims) { return result; } -// Converts a float value into an unsigned eight-bit quantized value. -uint8_t FloatToAsymmetricQuantizedUInt8(const float value, const float scale, - const int zero_point) { - int32_t result = round(value / scale) + zero_point; - if (result < kAsymmetricUInt8Min) { - result = kAsymmetricUInt8Min; - } - if (result > kAsymmetricUInt8Max) { - result = kAsymmetricUInt8Max; - } - return result; -} - -uint8_t FloatToSymmetricQuantizedUInt8(const float value, const float scale) { - int32_t result = round(value / scale); - if (result < kSymmetricUInt8Min) { - result = kSymmetricUInt8Min; - } - if (result > kSymmetricUInt8Max) { - result = kSymmetricUInt8Max; - } - return result; -} - -int8_t FloatToAsymmetricQuantizedInt8(const float value, const float scale, - const int zero_point) { - int32_t result = round(value / scale) + zero_point; - if (result < kAsymmetricInt8Min) { - result = kAsymmetricInt8Min; - } - if (result > kAsymmetricInt8Max) { - result = kAsymmetricInt8Max; - } - return result; -} - -int16_t FloatToAsymmetricQuantizedInt16(const float value, const float scale, - const int zero_point) { - int32_t result = round(value / scale) + zero_point; - if (result < kAsymmetricInt16Min) { - result = kAsymmetricInt16Min; - } - if (result > kAsymmetricInt16Max) { - result = kAsymmetricInt16Max; - } - return result; -} - -int8_t FloatToSymmetricQuantizedInt8(const float value, const float scale) { - return FloatToAsymmetricQuantizedInt8(value, scale, 0.0f); -} - -int32_t FloatToSymmetricQuantizedInt32(const float value, const float scale) { - float quantized = round(value / scale); - if (static_cast(quantized) > INT_MAX) { - quantized = static_cast(INT_MAX); - } else if (quantized < INT_MIN) { - quantized = static_cast INT_MIN; - } - - return static_cast(quantized); -} - -void AsymmetricQuantize(const float* input, int8_t* output, int num_elements, - float scale, int zero_point) { - for (int i = 0; i < num_elements; i++) { - output[i] = FloatToAsymmetricQuantizedInt8(input[i], scale, zero_point); - } -} - -void AsymmetricQuantize(const float* input, uint8_t* output, int num_elements, - float scale, int zero_point) { - for (int i = 0; i < num_elements; i++) { - output[i] = FloatToAsymmetricQuantizedUInt8(input[i], scale, zero_point); - } -} - -void AsymmetricQuantize(const float* input, int16_t* output, int num_elements, - float scale, int zero_point) { - for (int i = 0; i < num_elements; i++) { - output[i] = FloatToAsymmetricQuantizedInt16(input[i], scale, zero_point); - } -} - -void SymmetricQuantize(const float* input, int32_t* output, int num_elements, - float scale) { - for (int i = 0; i < num_elements; i++) { - output[i] = FloatToSymmetricQuantizedInt32(input[i], scale); - } -} - -void SymmetricPerChannelQuantize(const float* input, int32_t* output, - int num_elements, int num_channels, - float* scales) { - int elements_per_channel = num_elements / num_channels; - for (int i = 0; i < num_channels; i++) { - for (int j = 0; j < elements_per_channel; j++) { - output[i * elements_per_channel + j] = FloatToSymmetricQuantizedInt32( - input[i * elements_per_channel + j], scales[i]); - } - } -} - void SignedSymmetricPerChannelQuantize(const float* values, TfLiteIntArray* dims, int quantized_dimension, @@ -186,94 +64,17 @@ void SignedSymmetricPerChannelQuantize(const float* values, max = fmaxf(max, values[idx]); } scaling_factors[channel] = - fmaxf(fabs(min), fabs(max)) / kSymmetricInt8Scale; + fmaxf(fabs(min), fabs(max)) / std::numeric_limits::max(); for (int i = 0; i < per_channel_size; i++) { int idx = channel * channel_stride + i * stride; const int32_t quantized_value = static_cast(roundf(values[idx] / scaling_factors[channel])); // Clamp: just in case some odd numeric offset. - quantized_values[idx] = fminf( - kSymmetricInt8Scale, fmaxf(-kSymmetricInt8Scale, quantized_value)); + quantized_values[idx] = + fminf(std::numeric_limits::max(), + fmaxf(std::numeric_limits::min() + 1, quantized_value)); } } } -void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims, - int8_t* quantized_values, float* scaling_factor) { - int input_size = ElementCount(*dims); - - float min = 0; - float max = 0; - for (int i = 0; i < input_size; i++) { - min = fminf(min, values[i]); - max = fmaxf(max, values[i]); - } - *scaling_factor = fmaxf(fabs(min), fabs(max)) / kSymmetricInt8Scale; - for (int i = 0; i < input_size; i++) { - const int32_t quantized_value = - static_cast(roundf(values[i] / *scaling_factor)); - // Clamp: just in case some odd numeric offset. - quantized_values[i] = fminf(kSymmetricInt8Scale, - fmaxf(-kSymmetricInt8Scale, quantized_value)); - } -} - -void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims, - int16_t* quantized_values, float* scaling_factor) { - int input_size = ElementCount(*dims); - - float min = 0; - float max = 0; - for (int i = 0; i < input_size; i++) { - min = fminf(min, values[i]); - max = fmaxf(max, values[i]); - } - *scaling_factor = fmaxf(fabs(min), fabs(max)) / kSymmetricInt16Scale; - for (int i = 0; i < input_size; i++) { - const int32_t quantized_value = - static_cast(roundf(values[i] / *scaling_factor)); - // Clamp: just in case some odd numeric offset. - quantized_values[i] = fminf(kSymmetricInt16Scale, - fmaxf(-kSymmetricInt16Scale, quantized_value)); - } -} - -void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims, - int32_t* quantized_values, float* scaling_factor) { - int input_size = ElementCount(*dims); - - float min = 0; - float max = 0; - for (int i = 0; i < input_size; i++) { - min = fminf(min, values[i]); - max = fmaxf(max, values[i]); - } - - *scaling_factor = - fmaxf(fabs(min), fabs(max)) / static_cast(kSymmetricInt32Scale); - for (int i = 0; i < input_size; i++) { - const int32_t quantized_value = - static_cast(roundf(values[i] / *scaling_factor)); - // Clamp: just in case some odd numeric offset. - quantized_values[i] = fminf( - static_cast(kSymmetricInt32Scale), - fmaxf(static_cast(-kSymmetricInt32Scale), quantized_value)); - } -} - -void SymmetricQuantize(const float* values, TfLiteIntArray* dims, - uint8_t* quantized_values, float* scaling_factor) { - SignedSymmetricQuantize(values, dims, - reinterpret_cast(quantized_values), - scaling_factor); -} - -void SymmetricDequantize(const int8_t* values, const int size, - const float dequantization_scale, - float* dequantized_values) { - for (int i = 0; i < size; ++i) { - dequantized_values[i] = values[i] * dequantization_scale; - } -} - } // namespace tflite diff --git a/tensorflow/lite/micro/micro_utils.h b/tensorflow/lite/micro/micro_utils.h index 24aebad8a78..b9a3121a1f3 100644 --- a/tensorflow/lite/micro/micro_utils.h +++ b/tensorflow/lite/micro/micro_utils.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_MICRO_UTILS_H_ #define TENSORFLOW_LITE_MICRO_MICRO_UTILS_H_ -#include +#include +#include +#include #include "tensorflow/lite/c/common.h" @@ -26,23 +28,28 @@ namespace tflite { int ElementCount(const TfLiteIntArray& dims); -uint8_t FloatToAsymmetricQuantizedUInt8(const float value, const float scale, - const int zero_point); +// Converts a float value into a quantized value. Note that large values (close +// to max int and min int) may see significant error due to a lack of floating +// point granularity for large values. +template +T FloatToQuantizedType(const float value, const float scale, int zero_point) { + int32_t result = round(value / scale) + zero_point; + result = + std::max(static_cast(std::numeric_limits::min()), result); + result = + std::min(static_cast(std::numeric_limits::max()), result); + return result; +} -uint8_t FloatToSymmetricQuantizedUInt8(const float value, const float scale); - -int8_t FloatToAsymmetricQuantizedInt8(const float value, const float scale, - const int zero_point); - -int16_t FloatToAsymmetricQuantizedInt16(const float value, const float scale, - const int zero_point); - -int8_t FloatToSymmetricQuantizedInt8(const float value, const float scale); - -// Converts a float value into a signed thirty-two-bit quantized value. Note -// that values close to max int and min int may see significant error due to -// a lack of floating point granularity for large values. -int32_t FloatToSymmetricQuantizedInt32(const float value, const float scale); +template +T FloatToSymmetricQuantizedType(const float value, const float scale) { + int32_t result = round(value / scale); + result = + std::max(static_cast(std::numeric_limits::min() + 1), result); + result = + std::min(static_cast(std::numeric_limits::max()), result); + return result; +} // Helper methods to quantize arrays of floats to the desired format. // @@ -55,22 +62,34 @@ int32_t FloatToSymmetricQuantizedInt32(const float value, const float scale); // // The per-op quantization spec can be found here: // https://www.tensorflow.org/lite/performance/quantization_spec +template +void Quantize(const float* input, T* output, int num_elements, float scale, + int zero_point) { + for (int i = 0; i < num_elements; i++) { + output[i] = FloatToQuantizedType(input[i], scale, zero_point); + } +} -void AsymmetricQuantize(const float* input, int8_t* output, int num_elements, - float scale, int zero_point = 0); +template +void SymmetricQuantize(const float* input, T* output, int num_elements, + float scale) { + for (int i = 0; i < num_elements; i++) { + output[i] = FloatToSymmetricQuantizedType(input[i], scale); + } +} -void AsymmetricQuantize(const float* input, uint8_t* output, int num_elements, - float scale, int zero_point = 128); - -void AsymmetricQuantize(const float* input, int16_t* output, int num_elements, - float scale, int zero_point = 0); - -void SymmetricQuantize(const float* input, int32_t* output, int num_elements, - float scale); - -void SymmetricPerChannelQuantize(const float* input, int32_t* output, +template +void SymmetricPerChannelQuantize(const float* input, T* output, int num_elements, int num_channels, - float* scales); + float* scales) { + int elements_per_channel = num_elements / num_channels; + for (int i = 0; i < num_channels; i++) { + for (int j = 0; j < elements_per_channel; j++) { + output[i * elements_per_channel + j] = FloatToSymmetricQuantizedType( + input[i * elements_per_channel + j], scales[i]); + } + } +} void SignedSymmetricPerChannelQuantize(const float* values, TfLiteIntArray* dims, @@ -78,30 +97,35 @@ void SignedSymmetricPerChannelQuantize(const float* values, int8_t* quantized_values, float* scaling_factor); -void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims, - int8_t* quantized_values, float* scaling_factor); +// Quantizes inputs based on the values provided, choosing the smallest range +// which includes all input values. +template +void SymmetricQuantizeCalculateScales(const float* values, TfLiteIntArray* dims, + T* output, float* scale) { + int input_size = ElementCount(*dims); -void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims, - int16_t* quantized_values, float* scaling_factor); - -void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims, - int32_t* quantized_values, float* scaling_factor); - -void SymmetricQuantize(const float* values, TfLiteIntArray* dims, - uint8_t* quantized_values, float* scaling_factor); - -void SymmetricDequantize(const int8_t* values, const int size, - const float dequantization_scale, - float* dequantized_values); + float min = 0; + float max = 0; + for (int i = 0; i < input_size; i++) { + min = fminf(min, values[i]); + max = fmaxf(max, values[i]); + } + *scale = fmaxf(std::abs(min), std::abs(max)) / std::numeric_limits::max(); + for (int i = 0; i < input_size; i++) { + const int32_t quantized_value = + static_cast(roundf(values[i] / *scale)); + // Clamp: just in case some odd numeric offset. + quantized_value = fminf(std::numeric_limits::max(), quantized_value); + quantized_value = fmaxf(std::numeric_limits::min() + 1, quantized_value); + output[i] = quantized_value; + } +} template -void AsymmetricDequantize(const T* values, const int size, - const float dequantization_scale, - int dequantization_zero_point, - float* dequantized_values) { +void Dequantize(const T* values, const int size, const float scale, + int zero_point, float* dequantized_values) { for (int i = 0; i < size; ++i) { - dequantized_values[i] = - (values[i] - dequantization_zero_point) * dequantization_scale; + dequantized_values[i] = (values[i] - zero_point) * scale; } } diff --git a/tensorflow/lite/micro/micro_utils_test.cc b/tensorflow/lite/micro/micro_utils_test.cc index 7aa31130595..d74004eacee 100644 --- a/tensorflow/lite/micro/micro_utils_test.cc +++ b/tensorflow/lite/micro/micro_utils_test.cc @@ -20,63 +20,68 @@ limitations under the License. TF_LITE_MICRO_TESTS_BEGIN TF_LITE_MICRO_TEST(FloatToAsymmetricQuantizedUInt8Test) { - using tflite::FloatToAsymmetricQuantizedUInt8; + using tflite::FloatToQuantizedType; // [0, 127.5] -> zero_point=0, scale=0.5 - TF_LITE_MICRO_EXPECT_EQ(0, FloatToAsymmetricQuantizedUInt8(0, 0.5, 0)); - TF_LITE_MICRO_EXPECT_EQ(254, FloatToAsymmetricQuantizedUInt8(127, 0.5, 0)); - TF_LITE_MICRO_EXPECT_EQ(255, FloatToAsymmetricQuantizedUInt8(127.5, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(0, FloatToQuantizedType(0, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(254, FloatToQuantizedType(127, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(255, FloatToQuantizedType(127.5, 0.5, 0)); // [-10, 245] -> zero_point=10, scale=1.0 - TF_LITE_MICRO_EXPECT_EQ(0, FloatToAsymmetricQuantizedUInt8(-10, 1.0, 10)); - TF_LITE_MICRO_EXPECT_EQ(1, FloatToAsymmetricQuantizedUInt8(-9, 1.0, 10)); - TF_LITE_MICRO_EXPECT_EQ(128, FloatToAsymmetricQuantizedUInt8(118, 1.0, 10)); - TF_LITE_MICRO_EXPECT_EQ(253, FloatToAsymmetricQuantizedUInt8(243, 1.0, 10)); - TF_LITE_MICRO_EXPECT_EQ(254, FloatToAsymmetricQuantizedUInt8(244, 1.0, 10)); - TF_LITE_MICRO_EXPECT_EQ(255, FloatToAsymmetricQuantizedUInt8(245, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(0, FloatToQuantizedType(-10, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(1, FloatToQuantizedType(-9, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(128, FloatToQuantizedType(118, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(253, FloatToQuantizedType(243, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(254, FloatToQuantizedType(244, 1.0, 10)); + TF_LITE_MICRO_EXPECT_EQ(255, FloatToQuantizedType(245, 1.0, 10)); } TF_LITE_MICRO_TEST(FloatToAsymmetricQuantizedInt8Test) { - using tflite::FloatToAsymmetricQuantizedInt8; + using tflite::FloatToQuantizedType; // [-64, 63.5] -> zero_point=0, scale=0.5 - TF_LITE_MICRO_EXPECT_EQ(2, FloatToAsymmetricQuantizedInt8(1, 0.5, 0)); - TF_LITE_MICRO_EXPECT_EQ(4, FloatToAsymmetricQuantizedInt8(2, 0.5, 0)); - TF_LITE_MICRO_EXPECT_EQ(6, FloatToAsymmetricQuantizedInt8(3, 0.5, 0)); - TF_LITE_MICRO_EXPECT_EQ(-10, FloatToAsymmetricQuantizedInt8(-5, 0.5, 0)); - TF_LITE_MICRO_EXPECT_EQ(-128, FloatToAsymmetricQuantizedInt8(-64, 0.5, 0)); - TF_LITE_MICRO_EXPECT_EQ(127, FloatToAsymmetricQuantizedInt8(63.5, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(2, FloatToQuantizedType(1, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(4, FloatToQuantizedType(2, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(6, FloatToQuantizedType(3, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(-10, FloatToQuantizedType(-5, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(-128, FloatToQuantizedType(-64, 0.5, 0)); + TF_LITE_MICRO_EXPECT_EQ(127, FloatToQuantizedType(63.5, 0.5, 0)); // [-127, 128] -> zero_point=-1, scale=1.0 - TF_LITE_MICRO_EXPECT_EQ(0, FloatToAsymmetricQuantizedInt8(1, 1.0, -1)); - TF_LITE_MICRO_EXPECT_EQ(-1, FloatToAsymmetricQuantizedInt8(0, 1.0, -1)); - TF_LITE_MICRO_EXPECT_EQ(126, FloatToAsymmetricQuantizedInt8(127, 1.0, -1)); - TF_LITE_MICRO_EXPECT_EQ(127, FloatToAsymmetricQuantizedInt8(128, 1.0, -1)); - TF_LITE_MICRO_EXPECT_EQ(-127, FloatToAsymmetricQuantizedInt8(-126, 1.0, -1)); - TF_LITE_MICRO_EXPECT_EQ(-128, FloatToAsymmetricQuantizedInt8(-127, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(0, FloatToQuantizedType(1, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(-1, FloatToQuantizedType(0, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(126, FloatToQuantizedType(127, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(127, FloatToQuantizedType(128, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(-127, FloatToQuantizedType(-126, 1.0, -1)); + TF_LITE_MICRO_EXPECT_EQ(-128, FloatToQuantizedType(-127, 1.0, -1)); } TF_LITE_MICRO_TEST(FloatToSymmetricQuantizedInt8Test) { - using tflite::FloatToSymmetricQuantizedInt8; + using tflite::FloatToSymmetricQuantizedType; // [-64, 63.5] -> zero_point=0, scale=0.5 - TF_LITE_MICRO_EXPECT_EQ(2, FloatToSymmetricQuantizedInt8(1, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(4, FloatToSymmetricQuantizedInt8(2, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(6, FloatToSymmetricQuantizedInt8(3, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(-10, FloatToSymmetricQuantizedInt8(-5, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(-128, FloatToSymmetricQuantizedInt8(-64, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(127, FloatToSymmetricQuantizedInt8(63.5, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(2, FloatToSymmetricQuantizedType(1, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(4, FloatToSymmetricQuantizedType(2, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(6, FloatToSymmetricQuantizedType(3, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(-10, FloatToSymmetricQuantizedType(-5, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(-127, + FloatToSymmetricQuantizedType(-64, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(127, + FloatToSymmetricQuantizedType(63.5, 0.5)); // [-127, 128] -> zero_point=-1, scale=1.0 - TF_LITE_MICRO_EXPECT_EQ(1, FloatToSymmetricQuantizedInt8(1, 1.0)); - TF_LITE_MICRO_EXPECT_EQ(0, FloatToSymmetricQuantizedInt8(0, 1.0)); - TF_LITE_MICRO_EXPECT_EQ(127, FloatToSymmetricQuantizedInt8(127, 1.0)); - TF_LITE_MICRO_EXPECT_EQ(127, FloatToSymmetricQuantizedInt8(128, 1.0)); - TF_LITE_MICRO_EXPECT_EQ(-126, FloatToSymmetricQuantizedInt8(-126, 1.0)); - TF_LITE_MICRO_EXPECT_EQ(-127, FloatToSymmetricQuantizedInt8(-127, 1.0)); + TF_LITE_MICRO_EXPECT_EQ(1, FloatToSymmetricQuantizedType(1, 1.0)); + TF_LITE_MICRO_EXPECT_EQ(0, FloatToSymmetricQuantizedType(0, 1.0)); + TF_LITE_MICRO_EXPECT_EQ(127, FloatToSymmetricQuantizedType(127, 1.0)); + TF_LITE_MICRO_EXPECT_EQ(127, FloatToSymmetricQuantizedType(128, 1.0)); + TF_LITE_MICRO_EXPECT_EQ(-126, + FloatToSymmetricQuantizedType(-126, 1.0)); + TF_LITE_MICRO_EXPECT_EQ(-127, + FloatToSymmetricQuantizedType(-127, 1.0)); } TF_LITE_MICRO_TEST(FloatToAsymmetricQuantizedInt32Test) { - using tflite::FloatToSymmetricQuantizedInt32; - TF_LITE_MICRO_EXPECT_EQ(0, FloatToSymmetricQuantizedInt32(0, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(2, FloatToSymmetricQuantizedInt32(1, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(-2, FloatToSymmetricQuantizedInt32(-1, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(-100, FloatToSymmetricQuantizedInt32(-50, 0.5)); - TF_LITE_MICRO_EXPECT_EQ(100, FloatToSymmetricQuantizedInt32(50, 0.5)); + using tflite::FloatToSymmetricQuantizedType; + TF_LITE_MICRO_EXPECT_EQ(0, FloatToSymmetricQuantizedType(0, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(2, FloatToSymmetricQuantizedType(1, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(-2, FloatToSymmetricQuantizedType(-1, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(-100, + FloatToSymmetricQuantizedType(-50, 0.5)); + TF_LITE_MICRO_EXPECT_EQ(100, FloatToSymmetricQuantizedType(50, 0.5)); } TF_LITE_MICRO_TEST(AsymmetricQuantizeInt8) { @@ -84,7 +89,7 @@ TF_LITE_MICRO_TEST(AsymmetricQuantizeInt8) { int8_t goldens[] = {-20, -5, -3, -3, -1, 1, 3, 5, 7, 9}; constexpr int length = sizeof(values) / sizeof(float); int8_t quantized[length]; - tflite::AsymmetricQuantize(values, quantized, length, 0.5, 1); + tflite::Quantize(values, quantized, length, 0.5, 1); for (int i = 0; i < length; i++) { TF_LITE_MICRO_EXPECT_EQ(quantized[i], goldens[i]); } @@ -95,7 +100,7 @@ TF_LITE_MICRO_TEST(AsymmetricQuantizeUInt8) { uint8_t goldens[] = {106, 121, 123, 123, 125, 127, 129, 131, 133, 135}; constexpr int length = sizeof(values) / sizeof(float); uint8_t quantized[length]; - tflite::AsymmetricQuantize(values, quantized, length, 0.5, 127); + tflite::Quantize(values, quantized, length, 0.5, 127); for (int i = 0; i < length; i++) { TF_LITE_MICRO_EXPECT_EQ(quantized[i], goldens[i]); } diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc index 82a57890231..d83df27ebca 100644 --- a/tensorflow/lite/micro/test_helpers.cc +++ b/tensorflow/lite/micro/test_helpers.cc @@ -870,101 +870,17 @@ TfLiteFloatArray* FloatArrayFromFloats(const float* floats) { return reinterpret_cast(const_cast(floats)); } -TfLiteTensor CreateTensor(TfLiteIntArray* dims, bool is_variable) { - TfLiteTensor result; - result.dims = dims; - result.params = {}; - result.quantization = {kTfLiteNoQuantization, nullptr}; - result.is_variable = is_variable; - result.allocation_type = kTfLiteMemNone; - return result; -} - -TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims, - bool is_variable) { - TfLiteTensor result = CreateTensor(dims, is_variable); - result.type = kTfLiteFloat32; - result.data.f = const_cast(data); - result.bytes = ElementCount(*dims) * sizeof(float); - return result; -} - -void PopulateFloatTensor(TfLiteTensor* tensor, float* begin, float* end) { - float* p = begin; - float* v = tensor->data.f; - while (p != end) { - *v++ = *p++; - } -} - -TfLiteTensor CreateBoolTensor(const bool* data, TfLiteIntArray* dims, - bool is_variable) { - TfLiteTensor result = CreateTensor(dims, is_variable); - result.type = kTfLiteBool; - result.data.b = const_cast(data); - result.bytes = ElementCount(*dims) * sizeof(bool); - return result; -} - -TfLiteTensor CreateInt32Tensor(const int32_t* data, TfLiteIntArray* dims, - bool is_variable) { - TfLiteTensor result = CreateTensor(dims, is_variable); - result.type = kTfLiteInt32; - result.data.i32 = const_cast(data); - result.bytes = ElementCount(*dims) * sizeof(int32_t); - return result; -} - -TfLiteTensor CreateQuantizedTensor(const uint8_t* data, TfLiteIntArray* dims, - float scale, int zero_point, - bool is_variable) { - TfLiteTensor result = CreateTensor(dims, is_variable); - result.type = kTfLiteUInt8; - result.data.uint8 = const_cast(data); - result.params = {scale, zero_point}; - result.quantization = {kTfLiteAffineQuantization, nullptr}; - result.bytes = ElementCount(*dims) * sizeof(uint8_t); - return result; -} - -TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims, - float scale, int zero_point, - bool is_variable) { - TfLiteTensor result = CreateTensor(dims, is_variable); - result.type = kTfLiteInt8; - result.data.int8 = const_cast(data); - result.params = {scale, zero_point}; - result.quantization = {kTfLiteAffineQuantization, nullptr}; - result.bytes = ElementCount(*dims) * sizeof(int8_t); - return result; -} - -TfLiteTensor CreateQuantizedTensor(const int16_t* data, TfLiteIntArray* dims, - float scale, int zero_point, - bool is_variable) { - TfLiteTensor result = CreateTensor(dims, is_variable); - result.type = kTfLiteInt16; - result.data.i16 = const_cast(data); - result.params = {scale, zero_point}; - result.quantization = {kTfLiteAffineQuantization, nullptr}; - result.bytes = ElementCount(*dims) * sizeof(int16_t); - return result; -} - TfLiteTensor CreateQuantizedBiasTensor(const float* data, int32_t* quantized, TfLiteIntArray* dims, float input_scale, float weights_scale, bool is_variable) { float bias_scale = input_scale * weights_scale; tflite::SymmetricQuantize(data, quantized, ElementCount(*dims), bias_scale); - TfLiteTensor result = CreateTensor(dims, is_variable); - result.type = kTfLiteInt32; - result.data.i32 = const_cast(quantized); + // Quantized int32_t tensors always have a zero point of 0, since the range of // int32_t values is large, and because zero point costs extra cycles during // processing. - result.params = {bias_scale, 0}; - result.quantization = {kTfLiteAffineQuantization, nullptr}; - result.bytes = ElementCount(*dims) * sizeof(int32_t); + TfLiteTensor result = + CreateQuantizedTensor(quantized, dims, bias_scale, 0, is_variable); return result; } @@ -986,18 +902,15 @@ TfLiteTensor CreatePerChannelQuantizedBiasTensor( zero_points[i + 1] = 0; } - SymmetricPerChannelQuantize(input, quantized, input_size, num_channels, - scales_array); + SymmetricPerChannelQuantize(input, quantized, input_size, + num_channels, scales_array); affine_quant->scale = FloatArrayFromFloats(scales); affine_quant->zero_point = IntArrayFromInts(zero_points); affine_quant->quantized_dimension = quantized_dimension; - TfLiteTensor result = CreateTensor(dims, is_variable); - result.type = kTfLiteInt32; - result.data.i32 = const_cast(quantized); + TfLiteTensor result = CreateTensor(quantized, dims, is_variable); result.quantization = {kTfLiteAffineQuantization, affine_quant}; - result.bytes = ElementCount(*dims) * sizeof(int32_t); return result; } @@ -1020,11 +933,8 @@ TfLiteTensor CreateSymmetricPerChannelQuantizedTensor( affine_quant->zero_point = IntArrayFromInts(zero_points); affine_quant->quantized_dimension = quantized_dimension; - TfLiteTensor result = CreateTensor(dims, is_variable); - result.type = kTfLiteInt8; - result.data.int8 = const_cast(quantized); + TfLiteTensor result = CreateTensor(quantized, dims, is_variable); result.quantization = {kTfLiteAffineQuantization, affine_quant}; - result.bytes = ElementCount(*dims) * sizeof(int8_t); return result; } diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index 57c6c365662..1db0d81facc 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -22,10 +22,12 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite//kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/micro_utils.h" +#include "tensorflow/lite/portable_type_to_tflitetype.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -140,35 +142,42 @@ TfLiteIntArray* IntArrayFromInts(const int* int_array); // supplied array must be the size of the array expressed as a float. TfLiteFloatArray* FloatArrayFromFloats(const float* floats); -TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims, - bool is_variable = false); +template +TfLiteTensor CreateTensor(const T* data, TfLiteIntArray* dims, + const bool is_variable = false) { + TfLiteTensor result; + result.dims = dims; + result.params = {}; + result.quantization = {kTfLiteNoQuantization, nullptr}; + result.is_variable = is_variable; + result.allocation_type = kTfLiteMemNone; + result.type = typeToTfLiteType(); + // Const cast is used to allow passing in const and non-const arrays within a + // single CreateTensor method. A Const array should be used for immutable + // input tensors and non-const array should be used for mutable and output + // tensors. + result.data.data = const_cast(data); + result.quantization = {kTfLiteAffineQuantization, nullptr}; + result.bytes = ElementCount(*dims) * sizeof(T); + return result; +} -void PopulateFloatTensor(TfLiteTensor* tensor, float* begin, float* end); - -TfLiteTensor CreateBoolTensor(const bool* data, TfLiteIntArray* dims, - bool is_variable = false); - -TfLiteTensor CreateInt32Tensor(const int32_t*, TfLiteIntArray* dims, - bool is_variable = false); - -TfLiteTensor CreateQuantizedTensor(const uint8_t* data, TfLiteIntArray* dims, - float scale, int zero_point, - bool is_variable = false); - -TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims, - float scale, int zero_point, - bool is_variable = false); - -TfLiteTensor CreateQuantizedTensor(const int16_t* data, TfLiteIntArray* dims, - float scale, int zero_point, - bool is_variable = false); +template +TfLiteTensor CreateQuantizedTensor(const T* data, TfLiteIntArray* dims, + const float scale, const int zero_point = 0, + const bool is_variable = false) { + TfLiteTensor result = CreateTensor(data, dims, is_variable); + result.params = {scale, zero_point}; + result.quantization = {kTfLiteAffineQuantization, nullptr}; + return result; +} template TfLiteTensor CreateQuantizedTensor(const float* input, T* quantized, TfLiteIntArray* dims, float scale, int zero_point, bool is_variable = false) { int input_size = ElementCount(*dims); - tflite::AsymmetricQuantize(input, quantized, input_size, scale, zero_point); + tflite::Quantize(input, quantized, input_size, scale, zero_point); return CreateQuantizedTensor(quantized, dims, scale, zero_point, is_variable); } diff --git a/tensorflow/lite/micro/testing/stm32f4.robot b/tensorflow/lite/micro/testing/stm32f4.robot index d1d204f51e9..0833c0b0e11 100644 --- a/tensorflow/lite/micro/testing/stm32f4.robot +++ b/tensorflow/lite/micro/testing/stm32f4.robot @@ -17,7 +17,7 @@ Should Run Stm32f4 Test Execute Command $bin = @${BIN} Execute Script ${SCRIPT} - Create Terminal Tester ${UART} timeout=30 + Create Terminal Tester ${UART} timeout=60 Start Emulation Wait For Line On Uart ${EXPECTED} diff --git a/tensorflow/lite/micro/tools/ci_build/test_all.sh b/tensorflow/lite/micro/tools/ci_build/test_all.sh index 354d26d9102..a31b5d1382f 100755 --- a/tensorflow/lite/micro/tools/ci_build/test_all.sh +++ b/tensorflow/lite/micro/tools/ci_build/test_all.sh @@ -52,7 +52,7 @@ tensorflow/lite/micro/tools/ci_build/test_stm32f4.sh PRESUBMIT echo "Running Arduino tests at `date`" tensorflow/lite/micro/tools/ci_build/test_arduino.sh -echo "Running cortex_m_gcc_generic tests at `date`" -tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh +echo "Running cortex_m_generic tests at `date`" +tensorflow/lite/micro/tools/ci_build/test_cortex_m_generic.sh echo "Finished all micro tests at `date`" diff --git a/tensorflow/lite/micro/tools/ci_build/test_arduino.sh b/tensorflow/lite/micro/tools/ci_build/test_arduino.sh index e333e9e6cd9..006951e4cf0 100755 --- a/tensorflow/lite/micro/tools/ci_build/test_arduino.sh +++ b/tensorflow/lite/micro/tools/ci_build/test_arduino.sh @@ -25,14 +25,15 @@ cd "${ROOT_DIR}" source tensorflow/lite/micro/tools/ci_build/helper_functions.sh -readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean clean_downloads +readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean TARGET=arduino +TAGS=cmsis-nn # TODO(b/143715361): parallel builds do not work with generated files right now. readable_run make -f tensorflow/lite/micro/tools/make/Makefile \ TARGET=${TARGET} \ - TAGS="cmsis-nn" \ + TAGS=${TAGS} \ generate_arduino_zip readable_run tensorflow/lite/micro/tools/ci_build/install_arduino_cli.sh diff --git a/tensorflow/lite/micro/tools/ci_build/test_arduino_library.sh b/tensorflow/lite/micro/tools/ci_build/test_arduino_library.sh index 8770ea96980..3856cb849f8 100755 --- a/tensorflow/lite/micro/tools/ci_build/test_arduino_library.sh +++ b/tensorflow/lite/micro/tools/ci_build/test_arduino_library.sh @@ -33,7 +33,7 @@ rm -rf ${TEMP_BUILD_DIR} mkdir -p "${ARDUINO_HOME_DIR}/libraries" mkdir -p ${TEMP_BUILD_DIR} -unzip -q ${LIBRARY_ZIP} -d "${ARDUINO_LIBRARIES_DIR}" +unzip -o -q ${LIBRARY_ZIP} -d "${ARDUINO_LIBRARIES_DIR}" # Installs all dependencies for Arduino InstallLibraryDependencies () { @@ -51,7 +51,7 @@ InstallLibraryDependencies () { # commit is tested to work; if we bump the commit, we need to ensure that # the defines in ArduCAM/memorysaver.h are correct. wget -O /tmp/arducam-master.zip https://github.com/ArduCAM/Arduino/archive/e216049ba304048ec9bb29adfc2cc24c16f589b1/master.zip - unzip /tmp/arducam-master.zip -d /tmp + unzip -o /tmp/arducam-master.zip -d /tmp cp -r /tmp/Arduino-e216049ba304048ec9bb29adfc2cc24c16f589b1/ArduCAM "${ARDUINO_LIBRARIES_DIR}" } diff --git a/tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh b/tensorflow/lite/micro/tools/ci_build/test_cortex_m_generic.sh similarity index 68% rename from tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh rename to tensorflow/lite/micro/tools/ci_build/test_cortex_m_generic.sh index 596c88965e7..f2a43f72630 100755 --- a/tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh +++ b/tensorflow/lite/micro/tools/ci_build/test_cortex_m_generic.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,23 +24,24 @@ cd "${ROOT_DIR}" source tensorflow/lite/micro/tools/ci_build/helper_functions.sh -TARGET=cortex_m_gcc_generic +TARGET=cortex_m_generic +TAGS=cmsis-nn # TODO(b/143715361): downloading first to allow for parallel builds. -readable_run make -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4F third_party_downloads +readable_run make -f tensorflow/lite/micro/tools/make/Makefile TAGS=${TAGS} TARGET=${TARGET} TARGET_ARCH=cortex-m4 third_party_downloads # Build for Cortex-M4 (no FPU) without CMSIS readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean -readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} CORTEX_M_CORE=M4 microlite +readable_run make -j$(nproc) -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} TARGET_ARCH=cortex-m4 microlite # Build for Cortex-M4F (FPU present) without CMSIS readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean -readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} CORTEX_M_CORE=M4F microlite +readable_run make -j$(nproc) -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} TARGET_ARCH=cortex-m4+fp microlite # Build for Cortex-M4 (no FPU) with CMSIS readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean -readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4 microlite +readable_run make -j$(nproc) -f tensorflow/lite/micro/tools/make/Makefile TAGS=${TAGS} TARGET=${TARGET} TARGET_ARCH=cortex-m4 microlite # Build for Cortex-M4 (FPU present) with CMSIS readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean -readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4F microlite +readable_run make -j$(nproc) -f tensorflow/lite/micro/tools/make/Makefile TAGS=${TAGS} TARGET=${TARGET} TARGET_ARCH=cortex-m4+fp microlite diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index fa1aa3f0baf..49d7b66ce0b 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -37,6 +37,7 @@ TARGET := $(HOST_OS) TARGET_ARCH := $(HOST_ARCH) # Default compiler and tool names: +TOOLCHAIN:=gcc CXX_TOOL := g++ CC_TOOL := gcc AR_TOOL := ar @@ -123,15 +124,21 @@ CXXFLAGS := \ -fno-threadsafe-statics \ $(COMMON_FLAGS) -CCFLAGS := \ - -std=c11 \ - $(COMMON_FLAGS) +CCFLAGS := \ + -std=c11 \ + $(COMMON_FLAGS) ARFLAGS := -r -LDFLAGS += \ - -Wl,--fatal-warnings \ - -Wl,--gc-sections +ifeq ($(TOOLCHAIN), gcc) + ifneq ($(TARGET), osx) + # GCC on MacOS uses an LLVM backend so we avoid the additional linker flags + # that are unsupported with LLVM. + LDFLAGS += \ + -Wl,--fatal-warnings \ + -Wl,--gc-sections + endif +endif # override these in the makefile.inc for specific compiler targets TARGET_TOOLCHAIN_PREFIX := @@ -358,14 +365,17 @@ $(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_I # The target-specific makefile must have a name that is exactly # TARGET_makefile.inc and is only needed for cross-compilation (i.e. when TARGET # is different from the HOST_OS). -ifneq ($(TARGET),$(HOST_OS)) - # The arduino is also special in that it does not have an arduino_makefile but - # the TARGET=arduino is still used to create a directory for the generated - # artifacts. We are using a workaround right now and will be separating the - # project generation from the Makefile in the future. - ifneq ($(TARGET),arduino) +# There are also some other targets like arduino and CHRE that are also special +# in that they do no have a _makefile but are still used to create a +# directory for the generated artifacts. We are using a workaround right now and +# will be separating the project generation from the Makefile in the future. +TARGETS_WITHOUT_MAKEFILES := \ +$(HOST_OS) \ +arduino \ +chre + +ifeq ($(findstring $(TARGET),$(TARGETS_WITHOUT_MAKEFILES)),) include $(MAKEFILE_DIR)/targets/$(TARGET)_makefile.inc - endif endif # Load dependencies for optimized kernel implementations. diff --git a/tensorflow/lite/micro/tools/make/download_and_extract.sh b/tensorflow/lite/micro/tools/make/download_and_extract.sh index da98ab1b042..f384e6afb4d 100755 --- a/tensorflow/lite/micro/tools/make/download_and_extract.sh +++ b/tensorflow/lite/micro/tools/make/download_and_extract.sh @@ -91,100 +91,44 @@ patch_cmsis() { # custom include paths. # These include changes were found through trial and error while trying to get the Arduino # library compiling with the CMSIS-NN kernels included. - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.c' -exec \ - sed -i -E $'s@#include "arm_nnfunctions.h"@#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h"@g' {} \; + dspfiles="arm_math.h" + dspfiles+="\|arm_math_types.h" + dspfiles+="\|arm_math_memory.h" + dspfiles+="\|arm_common_tables.h" + dspfiles+="\|dsp/basic_math_functions.h" + dspfiles+="\|dsp/bayes_functions.h" + dspfiles+="\|dsp/complex_math_functions.h" + dspfiles+="\|dsp/controller_functions.h" + dspfiles+="\|dsp/distance_functions.h" + dspfiles+="\|dsp/fast_math_functions.h" + dspfiles+="\|dsp/filtering_functions.h" + dspfiles+="\|dsp/interpolation_functions.h" + dspfiles+="\|dsp/matrix_functions.h" + dspfiles+="\|dsp/none.h" + dspfiles+="\|dsp/statistics_functions.h" + dspfiles+="\|dsp/support_functions.h" + dspfiles+="\|dsp/svm_functions.h" + dspfiles+="\|dsp/svm_defines.h" + dspfiles+="\|dsp/transform_functions.h" + dspfiles+="\|dsp/utils.h" + dspfiles+="\|dsp/arm_helium_utils.h" find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.c' -exec \ - sed -i -E $'s@#include "arm_nnsupportfunctions.h"@#include "cmsis/CMSIS/NN/Include/arm_nnsupportfunctions.h"@g' {} \; + \( -name *.c -or -name *.h -or -name *.cpp \) -exec \ + sed -i "s@#include \"\($dspfiles\)\"@#include \"cmsis/CMSIS/DSP/Include/\1\"@g" {} \; + nnfiles="arm_nn_tables.h" + nnfiles+="\|arm_nnfunctions.h" + nnfiles+="\|arm_nnsupportfunctions.h" + nnfiles+="\|arm_nn_types.h" find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.c' -exec \ - sed -i -E $'s@#include "arm_nn_types.h"@#include "cmsis/CMSIS/NN/Include/arm_nn_types.h"@g' {} \; + \( -name *.c -or -name *.h -or -name *.cpp \) -exec \ + sed -i "s@#include \"\($nnfiles\)\"@#include \"cmsis/CMSIS/NN/Include/\1\"@g" {} \; + corefiles="cmsis_compiler.h" find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "arm_math.h"@#include "cmsis/CMSIS/DSP/Include/arm_math.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "arm_common_tables.h"@#include "cmsis/CMSIS/DSP/Include/arm_common_tables.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "arm_nn_tables.h"@#include "cmsis/CMSIS/NN/Include/arm_nn_tables.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "arm_math_types.h"@#include "cmsis/CMSIS/DSP/Include/arm_math_types.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "arm_math_memory.h"@#include "cmsis/CMSIS/DSP/Include/arm_math_memory.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/none.h"@#include "cmsis/CMSIS/DSP/Include/dsp/none.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/utils.h"@#include "cmsis/CMSIS/DSP/Include/dsp/utils.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/basic_math_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/basic_math_functions.h"@g' {} \; - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/statistics_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/statistics_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/svm_defines.h"@#include "cmsis/CMSIS/DSP/Include/dsp/svm_defines.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/svm_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/svm_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/support_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/support_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/transform_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/transform_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/bayes_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/bayes_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/complex_math_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/complex_math_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/controller_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/controller_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/distance_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/distance_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/fast_math_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/fast_math_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/filtering_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/filtering_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/interpolation_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/interpolation_functions.h"@g' {} \; - - find tensorflow/lite/micro/tools/make/downloads/cmsis \ - -iname '*.*' -exec \ - sed -i -E $'s@#include "dsp/matrix_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/matrix_functions.h"@g' {} \; + \( -name *.c -or -name *.h -or -name *.cpp \) -exec \ + sed -i "s@#include \"\($corefiles\)\"@#include \"cmsis/CMSIS/Core/Include/\1\"@g" {} \; # Until the fix for https://github.com/ARMmbed/mbed-os/issues/12568 is # rolled into Mbed version used on the Arduino IDE, we have to replace @@ -295,6 +239,7 @@ download_and_extract() { fi else echo "Error unsupported archive type. Failed to extract tool after download." + exit 1 fi rm -rf ${tempdir2} ${tempdir} diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc similarity index 95% rename from tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc rename to tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc index b988dac0c94..b838cd68d13 100644 --- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc +++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc @@ -8,7 +8,7 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),) THIRD_PARTY_DOWNLOADS += \ $(eval $(call add_third_party_download,$(CMSIS_URL),$(CMSIS_MD5),cmsis,patch_cmsis)) - CMSIS_PATH = $(MAKEFILE_DIR)/downloads/cmsis/ + CMSIS_PATH := $(MAKEFILE_DIR)/downloads/cmsis/ # List of files generated with: # find tensorflow/lite/micro/tools/make/downloads/cmsis/CMSIS/NN/Source/ -iname "*.c" @@ -92,23 +92,15 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),) # optimized kernels. We don't include all the possible CMSIS headers because # of their large number. See the RFC document for more details: # https://docs.google.com/document/d/14GRxeVEgSKgKBKAijO7oxnI49nLoTYBFQmPok-rG0cw + # Note: If you add a .h here, you must update patch_cmsis() in download_and_extract.sh as well. THIRD_PARTY_CC_HDRS += \ - $(CMSIS_PATH)CMSIS/NN/Include/arm_nnfunctions.h \ - $(CMSIS_PATH)CMSIS/NN/Include/arm_nnsupportfunctions.h \ - $(CMSIS_PATH)CMSIS/NN/Include/arm_nn_tables.h \ - $(CMSIS_PATH)CMSIS/NN/Include/arm_nn_types.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/arm_math.h \ + $(CMSIS_PATH)CMSIS/Core/Include/cmsis_compiler.h \ $(CMSIS_PATH)CMSIS/DSP/Include/arm_common_tables.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/arm_math_types.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/arm_helium_utils.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/arm_math.h \ $(CMSIS_PATH)CMSIS/DSP/Include/arm_math_memory.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/dsp/none.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/dsp/utils.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/arm_math_types.h \ $(CMSIS_PATH)CMSIS/DSP/Include/dsp/basic_math_functions.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/dsp/statistics_functions.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/dsp/svm_defines.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/dsp/svm_functions.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/dsp/support_functions.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/dsp/transform_functions.h \ $(CMSIS_PATH)CMSIS/DSP/Include/dsp/bayes_functions.h \ $(CMSIS_PATH)CMSIS/DSP/Include/dsp/complex_math_functions.h \ $(CMSIS_PATH)CMSIS/DSP/Include/dsp/controller_functions.h \ @@ -116,8 +108,18 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),) $(CMSIS_PATH)CMSIS/DSP/Include/dsp/fast_math_functions.h \ $(CMSIS_PATH)CMSIS/DSP/Include/dsp/filtering_functions.h \ $(CMSIS_PATH)CMSIS/DSP/Include/dsp/interpolation_functions.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/dsp/matrix_functions.h - + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/matrix_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/none.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/statistics_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/support_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/svm_defines.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/svm_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/transform_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/utils.h \ + $(CMSIS_PATH)CMSIS/NN/Include/arm_nn_tables.h \ + $(CMSIS_PATH)CMSIS/NN/Include/arm_nn_types.h \ + $(CMSIS_PATH)CMSIS/NN/Include/arm_nnfunctions.h \ + $(CMSIS_PATH)CMSIS/NN/Include/arm_nnsupportfunctions.h # Need to add the CMSIS Core includes path. # All other CMSIS header files are included with their relative path diff --git a/tensorflow/lite/micro/tools/make/ext_libs/ethosu.inc b/tensorflow/lite/micro/tools/make/ext_libs/ethosu.inc index 44de3ebfc7c..c136b3e7d1f 100644 --- a/tensorflow/lite/micro/tools/make/ext_libs/ethosu.inc +++ b/tensorflow/lite/micro/tools/make/ext_libs/ethosu.inc @@ -1,19 +1,37 @@ ifneq ($(filter ethos-u,$(ALL_TAGS)),) - # Don't want -lm flag - MICROLITE_LIBS := + # Arm Compiler will not link the Math library (see below), therefore we're filtering it out. + # See Fatal error: L6450U: Cannot find library m: + # "Arm Compiler is designed to run in a bare metal environment, + # and automatically includes implementations of these functions, + # and so no such flag is necessary." + # https://developer.arm.com/documentation/100891/0611/troubleshooting/general-troubleshooting-advice + MICROLITE_LIBS := $(filter-out -lm,$(MICROLITE_LIBS)) ifneq (,$(filter $(TARGET_ARCH), x86_64)) $(error target architecture x86_64 not supported) endif - THIRD_PARTY_DOWNLOADS += \ - $(eval $(call add_third_party_download,$(ETHOSU_URL),$(ETHOSU_MD5),ethosu,)) ETHOSU_DRIVER_PATH = $(MAKEFILE_DIR)/downloads/ethosu + # The driver need to be downloaded before the recursive_find below. + # That won't happen with the standard way of downloading by generating a + # target(call add_third_party_download), so instead use the shell function. + NEED_DOWNLOAD := YES + ifeq ($(NEED_DOWNLOAD),$(shell test -d $(ETHOSU_DRIVER_PATH) || echo $(NEED_DOWNLOAD))) + DOWNLOAD_SCRIPT := ./tensorflow/lite/micro/tools/make/download_and_extract.sh + DOWNLOAD_OK := OK + DOWNLOAD_STATUS := $(shell $(DOWNLOAD_SCRIPT) $(ETHOSU_URL) $(ETHOSU_MD5) $(ETHOSU_DRIVER_PATH) >&2 && echo $(DOWNLOAD_OK)) + ifneq ($(DOWNLOAD_OK),$(DOWNLOAD_STATUS)) + $(error $(DOWNLOAD_SCRIPT) failed) + endif + endif + # Currently there is a dependency to CMSIS-NN THIRD_PARTY_DOWNLOADS += \ $(eval $(call add_third_party_download,$(CMSIS_URL),$(CMSIS_MD5),cmsis,patch_cmsis)) - CMSIS_PATH = $(MAKEFILE_DIR)/downloads/cmsis/ + ifeq ($(CMSIS_PATH),) + CMSIS_PATH = $(MAKEFILE_DIR)/downloads/cmsis/ + endif THIRD_PARTY_CC_HDRS += $(call recursive_find,$(CMSIS_PATH)/CMSIS/Core/Include,*.h) THIRD_PARTY_CC_HDRS += $(call recursive_find,$(ETHOSU_DRIVER_PATH)/include,*.h) diff --git a/tensorflow/lite/micro/tools/make/targets/arc/arc_common.inc b/tensorflow/lite/micro/tools/make/targets/arc/arc_common.inc index 28c0fcd8571..c396c1076f3 100644 --- a/tensorflow/lite/micro/tools/make/targets/arc/arc_common.inc +++ b/tensorflow/lite/micro/tools/make/targets/arc/arc_common.inc @@ -123,6 +123,10 @@ endif CXXFLAGS := $(filter-out -std=c++11,$(CXXFLAGS)) CCFLAGS := $(filter-out -std=c11,$(CCFLAGS)) + + ldflags_to_remove = -Wl,--fatal-warnings -Wl,--gc-sections + LDFLAGS := $(filter-out $(ldflags_to_remove),$(LDFLAGS)) + MICROLITE_LIBS := $(filter-out -lm,$(MICROLITE_LIBS)) CXXFLAGS += $(PLATFORM_FLAGS) diff --git a/tensorflow/lite/micro/tools/make/targets/cortex_m_gcc_generic_makefile.inc b/tensorflow/lite/micro/tools/make/targets/cortex_m_gcc_generic_makefile.inc deleted file mode 100644 index dd7ccca7ba5..00000000000 --- a/tensorflow/lite/micro/tools/make/targets/cortex_m_gcc_generic_makefile.inc +++ /dev/null @@ -1,31 +0,0 @@ -# Generic Makefile target for ARM Cortex Mx gcc builds. -ifeq ($(TARGET), cortex_m_gcc_generic) - TARGET_ARCH := arm - TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- - export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH) - - $(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,)) - - PLATFORM_FLAGS = \ - -DTF_LITE_MCU_DEBUG_LOG \ - -Wno-type-limits \ - -funsigned-char \ - -mcpu=cortex-m4 \ - -mfpu=fpv4-sp-d16 \ - -mthumb \ - -fomit-frame-pointer - -ifeq ($(CORTEX_M_CORE), M4F) - PLATFORM_FLAGS += -mfloat-abi=hard -else ifeq ($(CORTEX_M_CORE), M4) - PLATFORM_FLAGS += -mfloat-abi=softfp -else ifeq ($(CORTEX_M_CORE), ) - $(error CORTEX_M_CORE=[M4|M4F] not defined on the command line) -else - $(error invalid target defined in command line option CORTEX_M_CORE=[M4|M4F]) -endif - - CXXFLAGS += $(PLATFORM_FLAGS) - CCFLAGS += $(PLATFORM_FLAGS) - -endif diff --git a/tensorflow/lite/micro/tools/make/targets/cortex_m_generic_makefile.inc b/tensorflow/lite/micro/tools/make/targets/cortex_m_generic_makefile.inc new file mode 100644 index 00000000000..6747ab9fc36 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/cortex_m_generic_makefile.inc @@ -0,0 +1,128 @@ +# Generic Makefile target for ARM Cortex M builds. +# For more info see: tensorflow/lite/micro/cortex_m_generic/README.md + +FLOAT := soft +GCC_TARGET_ARCH := $(TARGET_ARCH) + +ifeq ($(TARGET_ARCH), cortex-m0) + CORE=M0 + ARM_LDFLAGS := -Wl,--cpu=Cortex-M0 + +else ifeq ($(TARGET_ARCH), cortex-m3) + CORE=M3 + ARM_LDFLAGS := -Wl,--cpu=Cortex-M3 + +else ifeq ($(TARGET_ARCH), cortex-m33) + CORE=M33 + ARM_LDFLAGS := -Wl,--cpu=Cortex-M33 + TARGET_SPECIFIC_FLAGS += -D__DSP_PRESENT=1 -D__FPU_PRESENT=1 -D__VTOR_PRESENT=1 -D__FPU_USED=1 + FLOAT=hard + +else ifeq ($(TARGET_ARCH), cortex-m33+nodsp) + CORE=M33 + ARM_LDFLAGS := -Wl,--cpu=Cortex-M33.no_dsp.no_fp + +else ifeq ($(TARGET_ARCH), cortex-m4) + CORE=M4 + ARM_LDFLAGS := -Wl,--cpu=Cortex-M4.no_fp + +else ifeq ($(TARGET_ARCH), cortex-m4+fp) + CORE=M4 + ARM_LDFLAGS := -Wl,--cpu=Cortex-M4 + TARGET_SPECIFIC_FLAGS += -D__FPU_PRESENT=1 -mfpu=fpv4-sp-d16 + FLOAT=hard + GCC_TARGET_ARCH := cortex-m4 + +else ifeq ($(TARGET_ARCH), cortex-m55) + CORE=M55 + ARM_LDFLAGS := -Wl,--cpu=8.1-M.Main.mve.fp + TARGET_SPECIFIC_FLAGS += -D__DSP_PRESENT=1 -D__FPU_PRESENT=1 + FLOAT=hard + +else ifeq ($(TARGET_ARCH), cortex-m55+nodsp+nofp) + CORE=M55 + ARM_LDFLAGS := -Wl,--cpu=8.1-M.Main.mve.no_dsp.no_fp + +else ifeq ($(TARGET_ARCH), cortex-m55+nofp) + CORE=M55 + ARM_LDFLAGS := -Wl,--cpu=8.1-M.Main.mve.no_fp + TARGET_SPECIFIC_FLAGS += -D__DSP_PRESENT=1 + +else ifeq ($(TARGET_ARCH), cortex-m7) + CORE=M7 + ARM_LDFLAGS := -Wl,--cpu=Cortex-M7.no_fp + +else ifeq ($(TARGET_ARCH), cortex-m7+fp) + CORE=M7 + ARM_LDFLAGS := -Wl,--cpu=Cortex-M7 + FLOAT=hard + GCC_TARGET_ARCH := cortex-m7 + +else + $(error "TARGET_ARCH=$(TARGET_ARCH) is not supported") +endif + +ifneq ($(filter cortex-m55%,$(TARGET_ARCH)),) + ifeq ($(TOOLCHAIN), gcc) + $(error "Micro architecure support is not available for arm-gcc for TARGET_ARCH=$(TARGET_ARCH)") + endif + + # soft-abi=soft disables MVE - use softfp instead for M55. + ifeq ($(FLOAT),soft) + FLOAT=softfp + endif +endif + +# Toolchain specfic flags +ifeq ($(TOOLCHAIN), armclang) + CXX_TOOL := armclang + CC_TOOL := armclang + AR_TOOL := armar + LD := armlink + + FLAGS_ARMC = \ + --target=arm-arm-none-eabi \ + -mcpu=$(TARGET_ARCH) + + CXXFLAGS += $(FLAGS_ARMC) + CCFLAGS += $(FLAGS_ARMC) + LDFLAGS += $(ARM_LDFLAGS) + + # Arm Compiler will not link the Math library (see below), therefore we're filtering it out. + # See Fatal error: L6450U: Cannot find library m: + # "Arm Compiler is designed to run in a bare metal environment, + # and automatically includes implementations of these functions, + # and so no such flag is necessary." + # https://developer.arm.com/documentation/100891/0611/troubleshooting/general-troubleshooting-advice + MICROLITE_LIBS := $(filter-out -lm,$(MICROLITE_LIBS)) + +else ifeq ($(TOOLCHAIN), gcc) + export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH) + $(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,)) + + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- + + FLAGS_GCC = -mcpu=$(GCC_TARGET_ARCH) + CXXFLAGS += $(FLAGS_GCC) + CCFLAGS += $(FLAGS_GCC) + +else + $(error "TOOLCHAIN=$(TOOLCHAIN) is not supported.") +endif + +PLATFORM_FLAGS = \ + -DTF_LITE_MCU_DEBUG_LOG \ + -mthumb \ + -mfloat-abi=$(FLOAT) \ + -funsigned-char \ + -mlittle-endian \ + -Wno-type-limits \ + -Wno-unused-private-field \ + -fomit-frame-pointer \ + -MD \ + -DCPU_$(CORE)=1 \ + $(TARGET_SPECIFIC_FLAGS) + +# Common + C/C++ flags +CXXFLAGS += $(PLATFORM_FLAGS) +CCFLAGS += $(PLATFORM_FLAGS) diff --git a/tensorflow/lite/micro/tools/make/targets/himax_we1_evb_makefile.inc b/tensorflow/lite/micro/tools/make/targets/himax_we1_evb_makefile.inc index d19ce680b41..11c39867e31 100644 --- a/tensorflow/lite/micro/tools/make/targets/himax_we1_evb_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/himax_we1_evb_makefile.inc @@ -87,6 +87,10 @@ ifeq ($(TARGET), himax_we1_evb) CXXFLAGS := $(filter-out -std=c++11,$(CXXFLAGS)) CCFLAGS := $(filter-out -std=c11,$(CCFLAGS)) + + ldflags_to_remove = -Wl,--fatal-warnings -Wl,--gc-sections + LDFLAGS := $(filter-out $(ldflags_to_remove),$(LDFLAGS)) + MICROLITE_LIBS := $(filter-out -lm,$(MICROLITE_LIBS)) endif diff --git a/tensorflow/lite/micro/tools/make/targets/stm32f4/stm32f4.lds b/tensorflow/lite/micro/tools/make/targets/stm32f4/stm32f4.lds index 8e8b3f75448..1856368b4dc 100644 --- a/tensorflow/lite/micro/tools/make/targets/stm32f4/stm32f4.lds +++ b/tensorflow/lite/micro/tools/make/targets/stm32f4/stm32f4.lds @@ -14,26 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -/* Copied and modified from: tensorflow/lite/micro/tools/make/targets/bluepill/bluepill.lds - -*/ - -/* - * 0x00000000 - 0x07ffffff - aliased to flash or sys memory depending on BOOT jumpers. - * 0x08000000 - 0x0801ffff - Flash. - * 0x1ffff000 - 0x1ffff7ff - Boot firmware in system memory. - * 0x1ffff800 - 0x1fffffff - Option bytes. - * 0x20000000 - 0x20004fff - SRAM. - * 0x40000000 - 0x40023400 - Peripherals - */ - /* Define main entry point */ ENTRY(_main) -/* 32K of RAM and 256K of FLASH */ +/* 256K of RAM and 2048K of FLASH. Source: */ +/* https://github.com/renode/renode/blob/master/platforms/cpus/stm32f4.repl*/ MEMORY { -RAM (xrw) : ORIGIN = 0x20000000, LENGTH = 32K -FLASH (rx) : ORIGIN = 0x8000000, LENGTH = 256K + RAM (xrw) : ORIGIN = 0x20000000, LENGTH = 256K + FLASH (rx) : ORIGIN = 0x8000000, LENGTH = 2048K } /* Compute where the stack ends rather than hard coding it */ diff --git a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc index e9ee7296999..81cef4465b1 100644 --- a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc @@ -1,87 +1,82 @@ # Settings for stm32f4 based platforms -ifeq ($(TARGET), stm32f4) - export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH) - TARGET_ARCH := cortex-m4 - TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- - TARGET_TOOLCHAIN_ROOT := $(TENSORFLOW_ROOT)$(MAKEFILE_DIR)/downloads/gcc_embedded/bin/ - $(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,)) - $(eval $(call add_third_party_download,$(CMSIS_URL),$(CMSIS_MD5),cmsis,patch_cmsis)) - $(eval $(call add_third_party_download,$(STM32_BARE_LIB_URL),$(STM32_BARE_LIB_MD5),stm32_bare_lib,)) +export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH) +TARGET_ARCH := cortex-m4 +TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- +TARGET_TOOLCHAIN_ROOT := $(TENSORFLOW_ROOT)$(MAKEFILE_DIR)/downloads/gcc_embedded/bin/ - # TODO(b/161478030) : change - Wno - vla to - Wvla and remove - Wno-shadow once - # we have a solution for fixing / avoiding being tripped up by these warnings. - PLATFORM_FLAGS = \ - -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ - -DTF_LITE_STATIC_MEMORY \ - -DTF_LITE_MCU_DEBUG_LOG \ - -fmessage-length=0 \ - -fno-exceptions \ - -fno-unwind-tables \ - -ffunction-sections \ - -fdata-sections \ - -funsigned-char \ - -MMD \ - -mcpu=cortex-m4 \ - -mthumb \ - -Wall \ - -Wextra \ - -Wno-shadow \ - -Wno-vla \ - -Wno-strict-aliasing \ - -Wno-type-limits \ - -Wno-unused-parameter \ - -Wno-missing-field-initializers \ - -Wno-write-strings \ - -Wno-sign-compare \ - -Wunused-function \ - -fno-delete-null-pointer-checks \ - -fomit-frame-pointer \ - -g \ - -Os - CXXFLAGS += $(PLATFORM_FLAGS) -std=gnu++11 -fno-rtti -fno-use-cxa-atexit - CCFLAGS += $(PLATFORM_FLAGS) - LDFLAGS += \ - --specs=nosys.specs \ - -T ${TENSORFLOW_ROOT}$(MAKEFILE_DIR)/targets/stm32f4/stm32f4.lds \ - -Wl,-Map=${TENSORFLOW_ROOT}$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref \ - -Wl,--gc-sections - BUILD_TYPE := micro - MICROLITE_LIBS := \ - -lm - INCLUDES += \ - -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \ - -I$(MAKEFILE_DIR)/downloads/stm32_bare_lib/include/ - THIRD_PARTY_CC_SRCS += \ - $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.c) \ - $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.cc) - EXCLUDED_SRCS := \ - $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/debug_log.c - THIRD_PARTY_CC_SRCS := $(filter-out $(EXCLUDED_SRCS), $(THIRD_PARTY_CC_SRCS)) - MICROLITE_CC_SRCS := $(filter-out $(EXCLUDED_SRCS), $(MICROLITE_CC_SRCS)) - TEST_SCRIPT := tensorflow/lite/micro/testing/test_stm32f4_binary.sh - # TODO, non working tests.. the micro_speech example partly works - # TODO(b/158324045): Examine why some tests fail here. +$(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,)) +$(eval $(call add_third_party_download,$(CMSIS_URL),$(CMSIS_MD5),cmsis,patch_cmsis)) +$(eval $(call add_third_party_download,$(STM32_BARE_LIB_URL),$(STM32_BARE_LIB_MD5),stm32_bare_lib,)) - EXCLUDED_TESTS := \ - tensorflow/lite/micro/micro_interpreter_test.cc \ - tensorflow/lite/micro/micro_allocator_test.cc \ - tensorflow/lite/micro/memory_helpers_test.cc \ - tensorflow/lite/micro/memory_arena_threshold_test.cc \ - tensorflow/lite/micro/recording_micro_allocator_test.cc \ - tensorflow/lite/micro/kernels/circular_buffer_test.cc \ - tensorflow/lite/micro/kernels/conv_test.cc \ - tensorflow/lite/micro/kernels/fully_connected_test.cc +# TODO(b/161478030) : change - Wno - vla to - Wvla and remove - Wno-shadow once +# we have a solution for fixing / avoiding being tripped up by these warnings. +PLATFORM_FLAGS = \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -DTF_LITE_STATIC_MEMORY \ + -DTF_LITE_MCU_DEBUG_LOG \ + -fmessage-length=0 \ + -fno-exceptions \ + -fno-unwind-tables \ + -ffunction-sections \ + -fdata-sections \ + -funsigned-char \ + -MMD \ + -mcpu=cortex-m4 \ + -mthumb \ + -Wall \ + -Wextra \ + -Wno-shadow \ + -Wno-vla \ + -Wno-strict-aliasing \ + -Wno-type-limits \ + -Wno-unused-parameter \ + -Wno-missing-field-initializers \ + -Wno-write-strings \ + -Wno-sign-compare \ + -Wunused-function \ + -fno-delete-null-pointer-checks \ + -fomit-frame-pointer \ + -g \ + -Os +CXXFLAGS += $(PLATFORM_FLAGS) -std=gnu++11 -fno-rtti -fno-use-cxa-atexit +CCFLAGS += $(PLATFORM_FLAGS) +LDFLAGS += \ + --specs=nosys.specs \ + -T ${TENSORFLOW_ROOT}$(MAKEFILE_DIR)/targets/stm32f4/stm32f4.lds \ + -Wl,-Map=${TENSORFLOW_ROOT}$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref \ + -Wl,--gc-sections +BUILD_TYPE := micro +MICROLITE_LIBS := \ + -lm +INCLUDES += \ + -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \ + -I$(MAKEFILE_DIR)/downloads/stm32_bare_lib/include/ +THIRD_PARTY_CC_SRCS += \ + $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.c) \ + $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.cc) +EXCLUDED_SRCS := \ + $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/debug_log.c +THIRD_PARTY_CC_SRCS := $(filter-out $(EXCLUDED_SRCS), $(THIRD_PARTY_CC_SRCS)) +MICROLITE_CC_SRCS := $(filter-out $(EXCLUDED_SRCS), $(MICROLITE_CC_SRCS)) +TEST_SCRIPT := tensorflow/lite/micro/testing/test_stm32f4_binary.sh - MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) +# TODO(b/158324045): Examine why some tests fail here. +EXCLUDED_TESTS := \ + tensorflow/lite/micro/micro_interpreter_test.cc \ + tensorflow/lite/micro/micro_allocator_test.cc \ + tensorflow/lite/micro/memory_helpers_test.cc \ + tensorflow/lite/micro/memory_arena_threshold_test.cc \ + tensorflow/lite/micro/recording_micro_allocator_test.cc \ + tensorflow/lite/micro/kernels/circular_buffer_test.cc +MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) - EXCLUDED_EXAMPLE_TESTS := \ - tensorflow/lite/micro/examples/magic_wand/Makefile.inc \ - tensorflow/lite/micro/examples/person_detection/Makefile.inc \ - tensorflow/lite/micro/examples/person_detection_experimental/Makefile.inc \ - tensorflow/lite/micro/examples/micro_speech/Makefile.inc \ - tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc - MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS)) +EXCLUDED_EXAMPLE_TESTS := \ + tensorflow/lite/micro/examples/magic_wand/Makefile.inc \ + tensorflow/lite/micro/examples/micro_speech/Makefile.inc \ + tensorflow/lite/micro/examples/person_detection_experimental/Makefile.inc \ + tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc +MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS)) # These are microcontroller-specific rules for converting the ELF output # of the linker into a binary image that can be loaded directly. @@ -91,4 +86,3 @@ $(BINDIR)/%.bin: $(BINDIR)/% @mkdir -p $(dir $@) $(OBJCOPY) $< $@ -O binary -endif diff --git a/tensorflow/lite/micro/tools/make/templates/zephyr_cmake_project.cmake.tpl b/tensorflow/lite/micro/tools/make/templates/zephyr_cmake_project.cmake.tpl index d7bb4511f32..dc1eee5547d 100644 --- a/tensorflow/lite/micro/tools/make/templates/zephyr_cmake_project.cmake.tpl +++ b/tensorflow/lite/micro/tools/make/templates/zephyr_cmake_project.cmake.tpl @@ -2,13 +2,11 @@ cmake_minimum_required(VERSION 3.13.1) include($ENV{ZEPHYR_BASE}/cmake/app/boilerplate.cmake NO_POLICY_SCOPE) project(tf_lite_magic_wand) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} %{CXX_FLAGS}%") +# -fno-threadsafe-statics -- disables the mutex around initialization of local static variables +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} %{CXX_FLAGS}% -fno-threadsafe-statics -Wno-sign-compare -Wno-narrowing") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} %{CC_FLAGS}%") set(CMAKE_EXE_LINKER_FLAGS "%{LINKER_FLAGS}%") -# -fno-threadsafe-statics -- disables the mutex around initialization of local static variables -target_compile_options(app PRIVATE "-fno-threadsafe-statics") - target_sources(app PRIVATE %{SRCS}% ) diff --git a/tensorflow/lite/micro/xtensa_hifimini/debug_log.cc b/tensorflow/lite/micro/xtensa_hifimini/debug_log.cc deleted file mode 100644 index 45d9317478a..00000000000 --- a/tensorflow/lite/micro/xtensa_hifimini/debug_log.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Reference implementation of the DebugLog() function that's required for a -// platform to support the TensorFlow Lite for Microcontrollers library. This is -// the only function that's absolutely required to be available on a target -// device, since it's used for communicating test results back to the host so -// that we can verify the implementation is working correctly. -// It's designed to be as easy as possible to supply an implementation though. -// On platforms that have a POSIX stack or C library, it can be written as a -// single call to `fprintf(stderr, "%s", s)` to output a string to the error -// stream of the console, but if there's no OS or C library available, there's -// almost always an equivalent way to write out a string to some serial -// interface that can be used instead. For example on Arm M-series MCUs, calling -// the `bkpt #0xAB` assembler instruction will output the string in r1 to -// whatever debug serial connection is available. If you're running mbed, you -// can do the same by creating `Serial pc(USBTX, USBRX)` and then calling -// `pc.printf("%s", s)`. -// To add an equivalent function for your own platform, create your own -// implementation file, and place it in a subfolder with named after the OS -// you're targeting. For example, see the Cortex M bare metal version in -// tensorflow/lite/micro/bluepill/debug_log.cc or the mbed one on -// tensorflow/lite/micro/mbed/debug_log.cc. - -#include "tensorflow/lite/micro/debug_log.h" - -#ifndef TF_LITE_STRIP_ERROR_STRINGS -#include -#endif - -extern "C" void DebugLog(const char* s) { -#ifndef TF_LITE_STRIP_ERROR_STRINGS - // Reusing TF_LITE_STRIP_ERROR_STRINGS to disable DebugLog completely to get - // maximum reduction in binary size. This is because we have DebugLog calls - // via TF_LITE_CHECK that are not stubbed out by TF_LITE_REPORT_ERROR. - fprintf(stderr, "%s", s); -#endif -} diff --git a/tensorflow/lite/micro/xtensa_hifimini/micro_time.cc b/tensorflow/lite/micro/xtensa_hifimini/micro_time.cc index 6f3844c1fe3..22880657882 100644 --- a/tensorflow/lite/micro/xtensa_hifimini/micro_time.cc +++ b/tensorflow/lite/micro/xtensa_hifimini/micro_time.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Xtensa implementation of micro_timer. -// To include this with make, add TAGS=xtensa-xpg. +// Xtensa timer implementation. +// To include this with make, add TARGET=xtensa_hifimini. #include "tensorflow/lite/micro/micro_time.h" #include diff --git a/tensorflow/lite/portable_type_to_tflitetype.h b/tensorflow/lite/portable_type_to_tflitetype.h index 208efcce5b2..32423a4474b 100644 --- a/tensorflow/lite/portable_type_to_tflitetype.h +++ b/tensorflow/lite/portable_type_to_tflitetype.h @@ -58,7 +58,7 @@ struct TfLiteTypeToType {}; // Specializations below // No string mapping is included here, since the TF Lite packed representation // doesn't correspond to a C++ type well. -MATCH_TYPE_AND_TFLITE_TYPE(int, kTfLiteInt32); +MATCH_TYPE_AND_TFLITE_TYPE(int32_t, kTfLiteInt32); MATCH_TYPE_AND_TFLITE_TYPE(int16_t, kTfLiteInt16); MATCH_TYPE_AND_TFLITE_TYPE(int64_t, kTfLiteInt64); MATCH_TYPE_AND_TFLITE_TYPE(float, kTfLiteFloat32); diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index 5c863886523..a49b488d5e9 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -233,6 +233,7 @@ py_library( deps = [ ":op_hint", ":schema_py", + ":schema_util", "//tensorflow/lite/toco:toco_flags_proto_py", "//tensorflow/python:convert_to_constants", "//tensorflow/python:dtypes", @@ -401,3 +402,13 @@ sh_test( srcs = ["convert_file_to_c_source_test.sh"], data = [":convert_file_to_c_source"], ) + +py_library( + name = "schema_util", + srcs = ["schema_util.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow/lite/schema:utils_friends"], + deps = [ + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index bcb338b84cf..62bd9710f23 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -218,6 +218,18 @@ class InterpreterTest(test_util.TensorFlowTestCase): output_data = interpreter.get_tensor(output_details[0]['index']) self.assertTrue((expected_output == output_data).all()) + def testStringZeroDim(self): + data = b'abcd' + bytes(16) + interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/gather_string_0d.tflite')) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + interpreter.set_tensor(input_details[0]['index'], np.array(data)) + test_input_tensor = interpreter.get_tensor(input_details[0]['index']) + self.assertEqual(len(data), len(test_input_tensor.item(0))) + def testPerChannelParams(self): interpreter = interpreter_wrapper.Interpreter( model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin')) diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc index d2f308a74a2..b854e2ebd69 100644 --- a/tensorflow/lite/python/interpreter_wrapper/numpy.cc +++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc @@ -153,6 +153,11 @@ bool FillStringBufferWithPyArray(PyObject* value, case NPY_OBJECT: case NPY_STRING: case NPY_UNICODE: { + if (PyArray_NDIM(array) == 0) { + dynamic_buffer->AddString(static_cast(PyArray_DATA(array)), + PyArray_NBYTES(array)); + return true; + } UniquePyObjectRef iter(PyArray_IterNew(value)); while (PyArray_ITER_NOTDONE(iter.get())) { UniquePyObjectRef item(PyArray_GETITEM( diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index eb1e5509874..362145435a9 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -192,6 +192,7 @@ class QuantizationMode(object): """Post training int8 quantize, disallow float fallback.""" return (self._is_int8_target_required() and not self._is_int16x8_target_required() and + not self._is_allow_float() and self._representative_dataset is not None) def post_training_int8_allow_float(self): @@ -211,16 +212,14 @@ class QuantizationMode(object): return (self.post_training_int16x8_no_float() or self.post_training_int16x8_allow_float()) - def is_post_training_integer_quantize(self): - """Post training integer quantization.""" + def is_integer_quantize(self): return (self.is_post_training_integer_quantize_8() or - self.is_post_training_integer_quantize_16x8()) + self.is_post_training_integer_quantize_16x8() or + self.is_training_time_int8_allow_float()) - def training_time_int8_allow_float(self): - """Training-time int8 quantize, allow float fallback.""" + def is_training_time_int8_allow_float(self): return (self._any_optimization_enabled() and - not self.post_training_dynamic_range_int8() and - not self.post_training_fp16()) + self.contains_training_quant_op()) def post_training_int16x8_no_float(self): """Post training int16x8 quantize, disallow float fallback.""" @@ -249,11 +248,7 @@ class QuantizationMode(object): def fp32_execution(self): """If none of the above are true.""" - return not (self.post_training_int8_no_float() or - self.post_training_int8_allow_float() or - self.training_time_int8_allow_float() or - self.post_training_int16x8_no_float() or - self.post_training_int16x8_allow_float() or + return not (self.is_integer_quantize() or self.post_training_dynamic_range_int8() or self.post_training_fp16()) @@ -263,17 +258,12 @@ class QuantizationMode(object): def converter_flags(self, inference_ty=None, inference_input_ty=None): """Flags to the converter.""" - if self.is_post_training_integer_quantize(): - # The inference_input_type is for the quantizer, then we need to keep the - # converter inference_input_type to float. - inference_input_ty = _dtypes.float32 - if self.training_time_int8_allow_float(): + if self.is_integer_quantize(): return { "inference_type": inference_ty if inference_ty else \ self.activations_type(), - "inference_input_type": - inference_input_ty if inference_input_ty else _dtypes.float32, + "inference_input_type": _dtypes.float32, "post_training_quantize": False, # disable dynamic range quantization "quantize_to_float16": False # disable float16 quantization } @@ -326,19 +316,13 @@ class QuantizationMode(object): else: return False, None - def flags_modify_model_io_type( - self, input_type=_dtypes.float32, output_type=_dtypes.float32): + def flags_modify_model_io_type(self, input_ty=None, output_ty=None): """Flags for modifying the input and output type of a tflite model.""" - is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0] - is_training_time_only_quantize = self.training_time_int8_allow_float() and \ - not is_post_training_quantize - # TODO(b/153576658): Consolidate post/during training quantization workflows - # to modify model input/output type after MLIR conversion. - if is_training_time_only_quantize: + if self.is_integer_quantize(): return { - "inference_input_type": input_type, - "inference_output_type": output_type, + "inference_input_type": input_ty if input_ty else _dtypes.float32, + "inference_output_type": output_ty if output_ty else _dtypes.float32, } else: return None @@ -368,20 +352,18 @@ class QuantizationMode(object): "TFLITE_BUILTINS_INT8 or INT8 supported types.") def _is_int8_target_required(self): - return (set([OpsSet.TFLITE_BUILTINS_INT8]) == set( - self._target_spec.supported_ops) or - set(self._target_spec.supported_types) == set([_dtypes.int8])) + return (OpsSet.TFLITE_BUILTINS_INT8 in set( + self._target_spec.supported_ops)) or (set( + self._target_spec.supported_types) == set([_dtypes.int8])) def _is_int16x8_target_required(self): - return bool( - set(self._target_spec.supported_ops).intersection([ - OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 - ])) + return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 + in set(self._target_spec.supported_ops)) def _is_allow_float(self): - return bool( - set(self._target_spec.supported_ops).intersection( - [OpsSet.TFLITE_BUILTINS])) + return (OpsSet.TFLITE_BUILTINS in set( + self._target_spec.supported_ops)) or (OpsSet.SELECT_TF_OPS in set( + self._target_spec.supported_ops)) def _any_optimization_enabled(self): return bool( @@ -462,6 +444,8 @@ class TFLiteConverterBase(object): self.representative_dataset = RepresentativeDataset( self.representative_dataset) + # Add intermediate tensors to the model if needed. + result = _calibrator.add_intermediate_tensors(result) calibrate_quantize = _calibrator.Calibrator(result) if self._experimental_calibrate_only or self._experimental_new_quantizer: calibrated = calibrate_quantize.calibrate( @@ -561,7 +545,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase): """Validate inference_input_type and inference_output_type flags.""" default_types = [_dtypes.float32] # We support integer input/output for integer quantized models only. - if quant_mode.training_time_int8_allow_float(): + if quant_mode.is_integer_quantize(): if quant_mode.is_post_training_integer_quantize_16x8(): all_types = default_types + [_dtypes.int16] else: @@ -643,8 +627,7 @@ class TFLiteConverterBaseV2(TFLiteConverterBase): output_tensors=output_tensors, **converter_kwargs) - calibrate_and_quantize, flags = quant_mode.quantizer_flags( - self.inference_input_type, self.inference_output_type) + calibrate_and_quantize, flags = quant_mode.quantizer_flags() if calibrate_and_quantize: result = self._calibrate_quantize_model(result, **flags) @@ -754,8 +737,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2): converter_kwargs.update(quant_mode.converter_flags()) result = _convert_saved_model(**converter_kwargs) - calibrate_and_quantize, flags = quant_mode.quantizer_flags( - self.inference_input_type, self.inference_output_type) + calibrate_and_quantize, flags = quant_mode.quantizer_flags() if calibrate_and_quantize: result = self._calibrate_quantize_model(result, **flags) @@ -1305,8 +1287,11 @@ class TFLiteConverterBaseV1(TFLiteConverterBase): "please file a bug. You can opt-out " "by setting experimental_new_converter=False") - calibrate_quantize, flags = quant_mode.quantizer_flags( - self.inference_input_type, self.inference_output_type) + if not self.experimental_new_converter: + calibrate_quantize, flags = quant_mode.quantizer_flags( + self.inference_input_type, self.inference_output_type) + else: + calibrate_quantize, flags = quant_mode.quantizer_flags() self._validate_quantized_input_stats(converter_kwargs, calibrate_quantize) @@ -1327,6 +1312,12 @@ class TFLiteConverterBaseV1(TFLiteConverterBase): if calibrate_quantize: result = self._calibrate_quantize_model(result, **flags) + if self.experimental_new_converter: + flags_modify_model_io_type = quant_mode.flags_modify_model_io_type( + self.inference_input_type, self.inference_output_type) + if flags_modify_model_io_type: + result = _modify_model_io_type(result, **flags_modify_model_io_type) + if self._experimental_sparsify_model: result = _mlir_sparsify(result) diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 4851912226a..db38287d9c2 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -488,35 +488,92 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): converter.convert() self._assertValidDebugInfo(converter._debug_info) + def _getIntegerQuantizationModelWithFlexOp(self): + np.random.seed(0) + + root = tracking.AutoTrackable() + + @tf.function(input_signature=[ + tf.TensorSpec(shape=[3, 3, 3, 3, 3], dtype=tf.float32) + ]) + def func(inp): + tanh = tf.math.tanh(inp) + conv3d = tf.nn.conv3d( + tanh, + tf.ones([3, 3, 3, 3, 3]), + strides=[1, 1, 1, 1, 1], + padding='SAME') + output = tf.math.tanh(conv3d) + return output + + def calibration_gen(): + for _ in range(5): + yield [ + np.random.uniform(-1, 1, size=(3, 3, 3, 3, 3)).astype(np.float32) + ] + + root.f = func + return (root.f.get_concrete_function(), calibration_gen) + + @parameterized.named_parameters( + ('_Default', False, False, dtypes.float32), + ('_INT8InputOutput', False, False, dtypes.int8), + ('_UINT8InputOutput', False, False, dtypes.uint8), + ('_INT16Quantize', False, True, dtypes.float32), + ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16), + ('_IntOnly', True, False, dtypes.float32), + ('_IntOnly_INT8InputOutput', True, False, dtypes.int8), + ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8), + ('_IntOnly_INT16Quantize', True, True, dtypes.float32), + ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16)) @test_util.run_v2_only - def testFlexOpWithInt8OpSet(self): - model = tf.keras.Sequential() - input_shape = (1, 4, 4, 4, 1) - model.add( - tf.keras.layers.Conv3D( - 4, - kernel_size=(1, 1, 1), - activation='relu', - input_shape=input_shape[1:])) - model.add(tf.keras.layers.Flatten()) - model.add(tf.keras.layers.Dense(2, activation='relu')) + def testIntegerQuantizationWithFlexOp(self, is_int_only, is_int16_quantize, + inference_input_output_type): + func, calibration_gen = self._getIntegerQuantizationModelWithFlexOp() - @tf.function( - input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)]) - def _call_fn(inputs): - return model(inputs, training=False) + quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions( + [func]) + quantized_converter.optimizations = [lite.Optimize.DEFAULT] + quantized_converter.representative_dataset = calibration_gen + if is_int_only: + if is_int16_quantize: + quantized_converter.target_spec.supported_ops = [ + lite.OpsSet.\ + EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8, + lite.OpsSet.SELECT_TF_OPS + ] + else: + quantized_converter.target_spec.supported_ops = [ + lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.SELECT_TF_OPS + ] + else: + if is_int16_quantize: + quantized_converter.target_spec.supported_ops = [ + lite.OpsSet.\ + EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8, + lite.OpsSet.TFLITE_BUILTINS, + lite.OpsSet.SELECT_TF_OPS + ] + else: + quantized_converter.target_spec.supported_ops = [ + lite.OpsSet.TFLITE_BUILTINS, lite.OpsSet.SELECT_TF_OPS + ] - concrete_func = _call_fn.get_concrete_function( - tf.TensorSpec(input_shape, dtype=tf.float32)) + quantized_converter.inference_input_type = inference_input_output_type + quantized_converter.inference_output_type = inference_input_output_type + quantized_tflite_model = quantized_converter.convert() + self.assertIsNotNone(quantized_tflite_model) - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS_INT8, - tf.lite.OpsSet.SELECT_TF_OPS, - ] - tflite_model = converter.convert() - self.assertTrue(tflite_model) + interpreter = Interpreter(model_content=quantized_tflite_model) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + self.assertLen(input_details, 1) + self.assertEqual(inference_input_output_type.as_numpy_dtype, + input_details[0]['dtype']) + output_details = interpreter.get_output_details() + self.assertLen(output_details, 1) + self.assertEqual(inference_input_output_type.as_numpy_dtype, + output_details[0]['dtype']) class FromSavedModelTest(lite_v2_test_util.ModelTest): diff --git a/tensorflow/lite/python/optimize/BUILD b/tensorflow/lite/python/optimize/BUILD index b921fc45cde..c1956cc5b2d 100644 --- a/tensorflow/lite/python/optimize/BUILD +++ b/tensorflow/lite/python/optimize/BUILD @@ -16,6 +16,7 @@ cc_library( "//tensorflow/lite/python/interpreter_wrapper:numpy", "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter", "//tensorflow/lite/python/interpreter_wrapper:python_utils", + "//tensorflow/lite/tools/optimize:quantization_wrapper_utils", "//tensorflow/lite/tools/optimize:quantize_model", "//tensorflow/lite/tools/optimize/calibration:calibration_reader", "//tensorflow/lite/tools/optimize/calibration:calibrator_lib", diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index 7639e493e83..53d9aada15a 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h" #include "tensorflow/lite/tools/optimize/calibration/calibrator.h" +#include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h" #include "tensorflow/lite/tools/optimize/quantize_model.h" #define TFLITE_PY_CHECK(x) \ @@ -94,6 +95,42 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { } // namespace +PyObject* AddIntermediateTensors(PyObject* data) { + using tflite::interpreter_wrapper::PythonErrorReporter; + char* buf = nullptr; + Py_ssize_t length; + std::unique_ptr error_reporter(new PythonErrorReporter); + ::tflite::python::ImportNumpy(); + + if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) { + return nullptr; + } + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromBuffer(buf, length, + error_reporter.get()); + if (!model) { + PyErr_Format(PyExc_ValueError, "Invalid model"); + return nullptr; + } + flatbuffers::FlatBufferBuilder builder; + auto tflite_model = CreateMutableModel(*model->GetModel()); + if (optimize::AddIntermediateTensorsToFusedOp(&builder, tflite_model.get()) != + kTfLiteOk) { + error_reporter->exception(); + return nullptr; + } + + if (builder.GetSize()) { + return python_utils::ConvertToPyString( + reinterpret_cast(builder.GetCurrentBufferPointer()), + builder.GetSize()); + } else { + // When AddIntermediateTensorsToFusedOp early returns, return the model as + // it is. + return python_utils::ConvertToPyString(buf, length); + } +} + CalibrationWrapper::CalibrationWrapper( std::unique_ptr interpreter, std::unique_ptr resolver, diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.h b/tensorflow/lite/python/optimize/calibration_wrapper.h index 94aa0ed6f7f..4c81499c10c 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.h +++ b/tensorflow/lite/python/optimize/calibration_wrapper.h @@ -50,6 +50,8 @@ class CalibrationReader; namespace calibration_wrapper { +PyObject* AddIntermediateTensors(PyObject* data); + class CalibrationWrapper { public: // SWIG caller takes ownership of pointer. diff --git a/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc b/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc index 3f366615edc..5296d2796ab 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/python/lib/core/pybind11_lib.h" namespace py = pybind11; +using tflite::calibration_wrapper::AddIntermediateTensors; using tflite::calibration_wrapper::CalibrationWrapper; PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) { @@ -25,6 +26,9 @@ PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) { _pywrap_tensorflow_lite_calibration_wrapper ----- )pbdoc"; + m.def("AddIntermediateTensors", [](py::handle& data) { + return tensorflow::PyoOrThrow(AddIntermediateTensors(data.ptr())); + }); py::class_(m, "CalibrationWrapper") .def(py::init([](py::handle& data) { return ::CalibrationWrapper::CreateWrapperCPPFromBuffer(data.ptr()); diff --git a/tensorflow/lite/python/optimize/calibrator.py b/tensorflow/lite/python/optimize/calibrator.py index dfef8b9cb79..e1758e87eeb 100644 --- a/tensorflow/lite/python/optimize/calibrator.py +++ b/tensorflow/lite/python/optimize/calibrator.py @@ -31,6 +31,11 @@ _calibration_wrapper = LazyLoader( "_pywrap_tensorflow_lite_calibration_wrapper") +def add_intermediate_tensors(model_content): + """Adds intermedaite tensors to fused op if needed.""" + return _calibration_wrapper.AddIntermediateTensors(model_content) + + class Calibrator(object): """Calibrates a floating point model and then quantizes it. diff --git a/tensorflow/lite/python/optimize/calibrator_test.py b/tensorflow/lite/python/optimize/calibrator_test.py index a9ab12c6095..49fafa0ff0a 100644 --- a/tensorflow/lite/python/optimize/calibrator_test.py +++ b/tensorflow/lite/python/optimize/calibrator_test.py @@ -199,5 +199,13 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): quantized_model = quantizer.calibrate(input_gen) self.assertIsNotNone(quantized_model) + def test_add_intermediate_tensors(self): + model_path = resource_loader.get_path_to_datafile( + 'test_data/mobilenet_like_model.bin') + model = open(model_path, 'rb').read() + added_model = _calibrator.add_intermediate_tensors(model) + self.assertIsNotNone(added_model) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/lite/python/schema_util.py b/tensorflow/lite/python/schema_util.py new file mode 100644 index 00000000000..ea4092810f8 --- /dev/null +++ b/tensorflow/lite/python/schema_util.py @@ -0,0 +1,50 @@ +# Lint as: python2, python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Schema utilities to get builtin code from operator code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util import all_util + + +def get_builtin_code_from_operator_code(opcode): + """Return the builtin code of the given operator code. + + The following method is introduced to resolve op builtin code shortage + problem. The new builtin opreator will be assigned to the extended builtin + code field in the flatbuffer schema. Those methods helps to hide builtin code + details. + + Args: + opcode: Operator code. + + Returns: + The builtin code of the given operator code. + """ + # Access BuiltinCode() method first if available. + if hasattr(opcode, 'BuiltinCode') and callable(opcode.BuiltinCode): + return max(opcode.BuiltinCode(), opcode.DeprecatedBuiltinCode()) + + return max(opcode.builtinCode, opcode.deprecatedBuiltinCode) + + +_allowed_symbols = [ + 'get_builtin_code_from_operator_code', +] + +all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/lite/python/testdata/BUILD b/tensorflow/lite/python/testdata/BUILD index d98e401c76a..83f2d14666b 100644 --- a/tensorflow/lite/python/testdata/BUILD +++ b/tensorflow/lite/python/testdata/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/lite:build_def.bzl", "tf_to_tflite") +load("//tensorflow/lite:build_def.bzl", "DEPRECATED_tf_to_tflite") load("//tensorflow:tensorflow.bzl", "pybind_extension") package( @@ -8,7 +8,7 @@ package( exports_files(glob(["*.pb"])) -tf_to_tflite( +DEPRECATED_tf_to_tflite( name = "permute_float", src = "permute.pbtxt", out = "permute_float.tflite", @@ -18,7 +18,7 @@ tf_to_tflite( ], ) -tf_to_tflite( +DEPRECATED_tf_to_tflite( name = "permute_uint8", src = "permute.pbtxt", out = "permute_uint8.tflite", @@ -33,7 +33,7 @@ tf_to_tflite( ], ) -tf_to_tflite( +DEPRECATED_tf_to_tflite( name = "gather_string", src = "gather.pbtxt", out = "gather_string.tflite", @@ -43,11 +43,22 @@ tf_to_tflite( ], ) +DEPRECATED_tf_to_tflite( + name = "gather_string_0d", + src = "gather_0d.pbtxt", + out = "gather_string_0d.tflite", + options = [ + "--input_arrays=input,indices", + "--output_arrays=output", + ], +) + filegroup( name = "interpreter_test_data", srcs = [ "pc_conv.bin", ":gather_string", + ":gather_string_0d", ":permute_float", ":permute_uint8", ], diff --git a/tensorflow/lite/python/testdata/gather_0d.pbtxt b/tensorflow/lite/python/testdata/gather_0d.pbtxt new file mode 100644 index 00000000000..b065cb22a4e --- /dev/null +++ b/tensorflow/lite/python/testdata/gather_0d.pbtxt @@ -0,0 +1,108 @@ +node { + name: "input" + op: "Placeholder" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_STRING + } + } +} +node { + name: "input_const" + op: "Const" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "abcd" + } + } + } +} +node { + name: "indices" + op: "Placeholder" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + } + } + } +} +node { + name: "axis" + op: "Const" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "output" + op: "GatherV2" + input: "input_const" + input: "indices" + input: "axis" + device: "/device:CPU:0" + attr { + key: "Taxis" + value { + type: DT_INT32 + } + } + attr { + key: "Tindices" + value { + type: DT_INT64 + } + } + attr { + key: "Tparams" + value { + type: DT_STRING + } + } +} +versions { + producer: 27 +} diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index 4fc00778b74..fac9fb2b435 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -32,6 +32,7 @@ from tensorflow.core.protobuf import config_pb2 as _config_pb2 from tensorflow.core.protobuf import graph_debug_info_pb2 from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 from tensorflow.lite.python import schema_py_generated as schema_fb +from tensorflow.lite.python import schema_util from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes from tensorflow.lite.toco import types_pb2 as _types_pb2 @@ -76,7 +77,10 @@ _MAP_TFLITE_ENUM_TO_TF_TYPES = { _TFLITE_FILE_IDENTIFIER = b"TFL3" -_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (dtypes.float32, dtypes.int8, dtypes.uint8) +_MAP_QUANT_TO_IO_TYPES = { + dtypes.int8: {dtypes.int8, dtypes.uint8}, + dtypes.int16: {dtypes.int16}, +} def convert_dtype_to_tflite_type(tf_dtype): @@ -373,7 +377,7 @@ def build_debug_info_func(original_graph): (func, sub_func.graph.get_operation_by_name(name))) else: sys.stderr.write( - "Use '@tf.function' or '@defun' to decorate the function.") + "Use '@tf.function' or '@defun' to decorate the function.\n") continue except KeyError: # New node created by graph optimizer. No stack trace from source code. @@ -631,13 +635,6 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32): if inference_input_type == dtypes.float32: return - if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES: - raise ValueError( - "Unsupported `inference_output_type` value. Expected to be in {}, " - "instead got {}.".format(tuple(_get_tf_type_name(t) for t in - _TFLITE_MODEL_INPUT_OUTPUT_TYPES), - _get_tf_type_name(inference_input_type))) - subgraph = model.subgraphs[0] tensors = subgraph.tensors operators = subgraph.operators @@ -645,30 +642,44 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32): # Find all quantize operators quant_opcode_idxs = [] for idx, opcode in enumerate(model.operatorCodes): - if opcode.builtinCode == schema_fb.BuiltinOperator.QUANTIZE: + builtin_code = schema_util.get_builtin_code_from_operator_code(opcode) + if builtin_code == schema_fb.BuiltinOperator.QUANTIZE: quant_opcode_idxs.append(idx) if not quant_opcode_idxs: raise ValueError("Model input is not quantized.") - # Ensure that the model input is quantized + # Validate that the model input is quantized input_quant_ops = [] for op in operators: - # Check if the operator quantizes an input + # Find operators that quantize model input if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs: - # If found, validate the operator input/output tensor types - float_tensor, int_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]] - if float_tensor.type != schema_fb.TensorType.FLOAT32: + float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]] + # If found, validate that the operator's input type is float + float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type) + if float_type != dtypes.float32: raise ValueError( - "Model input type must be tf.float32. Expected type for tensor " - "with name '{}' is tf.float32, instead type is {}".format( - float_tensor.name, _get_tf_type_name( - _convert_tflite_enum_type_to_tf_type(float_tensor.type)))) - if int_tensor.type != schema_fb.TensorType.INT8: + "Initial model input type must be tf.float32. Expected type for " + "tensor with name '{}' is tf.float32, instead type is {}".format( + float_tensor.name, _get_tf_type_name(float_type))) + # If found, validate that the operator output is quantized and compatible + # with the final model input type + quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type) + if quant_type not in _MAP_QUANT_TO_IO_TYPES: raise ValueError( - "Model input is not quantized. Expected type for tensor " - "with name '{}' is tf.int8, instead type is {}".format( - int_tensor.name, _get_tf_type_name( - _convert_tflite_enum_type_to_tf_type(int_tensor.type)))) + "Initial model input is not quantized. Expected type for " + "tensor with name '{}' should be in {}, instead type is {}".format( + quant_tensor.name, + tuple(_get_tf_type_name(t) for t in + _MAP_QUANT_TO_IO_TYPES.keys()), + _get_tf_type_name(quant_type))) + else: + inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type] + if inference_input_type not in inference_io_types: + raise ValueError( + "Unsupported `inference_input_type` value. Expected to be in " + "{}, instead got {}.".format( + tuple(_get_tf_type_name(t) for t in inference_io_types), + _get_tf_type_name(inference_input_type))) input_quant_ops.append(op) if len(subgraph.inputs) != len(input_quant_ops): @@ -684,7 +695,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32): uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] tensors[op.inputs[0]].quantization = uint8_quantization tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8 - elif inference_input_type == dtypes.int8: + elif inference_input_type in _MAP_QUANT_TO_IO_TYPES: # Remove the inputs and the quant operator remove_tensors_idxs = set() for op in input_quant_ops: @@ -695,10 +706,8 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32): _remove_tensors_from_model(model, remove_tensors_idxs) else: raise ValueError( - "Unsupported `inference_input_type` value. Expected to be in {}, " - "instead got {}.".format(tuple(_get_tf_type_name(t) for t in - _TFLITE_MODEL_INPUT_OUTPUT_TYPES), - _get_tf_type_name(inference_input_type))) + "Unsupported `inference_input_type` value {}.".format( + _get_tf_type_name(inference_input_type))) def _modify_model_output_type(model, inference_output_type=dtypes.float32): @@ -707,13 +716,6 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32): if inference_output_type == dtypes.float32: return - if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES: - raise ValueError( - "Unsupported `inference_output_type` value. Expected to be in {}, " - "instead got {}.".format(tuple(_get_tf_type_name(t) for t in - _TFLITE_MODEL_INPUT_OUTPUT_TYPES), - _get_tf_type_name(inference_output_type))) - subgraph = model.subgraphs[0] tensors = subgraph.tensors operators = subgraph.operators @@ -721,31 +723,45 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32): # Find all dequantize operators dequant_opcode_idxs = [] for idx, opcode in enumerate(model.operatorCodes): - if opcode.builtinCode == schema_fb.BuiltinOperator.DEQUANTIZE: + builtin_code = schema_util.get_builtin_code_from_operator_code(opcode) + if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE: dequant_opcode_idxs.append(idx) if not dequant_opcode_idxs: raise ValueError("Model output is not dequantized.") - # Ensure that the model output is dequantized + # Validate that the model output is dequantized output_dequant_ops = [] for op in operators: - # Check if the operator dequantizes an output + # Find operators that dequantize model output if op.opcodeIndex in dequant_opcode_idxs and \ op.outputs[0] in subgraph.outputs: - # If found, validate the operator input/output tensor types - int_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]] - if float_tensor.type != schema_fb.TensorType.FLOAT32: + # If found, validate that the operator's output type is float + quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]] + float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type) + if float_type != dtypes.float32: raise ValueError( - "Model output type must be tf.float32. Expected type for tensor " - "with name '{}' is tf.float32, instead type is {}".format( - float_tensor.name, _get_tf_type_name( - _convert_tflite_enum_type_to_tf_type(float_tensor.type)))) - if int_tensor.type != schema_fb.TensorType.INT8: + "Initial model output type must be tf.float32. Expected type for " + "tensor with name '{}' is tf.float32, instead type is {}".format( + float_tensor.name, _get_tf_type_name(float_type))) + # If found, validate that the operator input is quantized and compatible + # with the final model output type + quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type) + if quant_type not in _MAP_QUANT_TO_IO_TYPES: raise ValueError( - "Model output is not dequantized. Expected type for tensor " - "with name '{}' is tf.int8, instead type is {}".format( - int_tensor.name, _get_tf_type_name( - _convert_tflite_enum_type_to_tf_type(int_tensor.type)))) + "Initial model output is not dequantized. Expected type for " + "tensor with name '{}' should be in {}, instead type is {}".format( + quant_tensor.name, + tuple(_get_tf_type_name(t) for t in + _MAP_QUANT_TO_IO_TYPES.keys()), + _get_tf_type_name(quant_type))) + else: + inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type] + if inference_output_type not in inference_io_types: + raise ValueError( + "Unsupported `inference_output_type` value. Expected to be in " + "{}, instead got {}.".format( + tuple(_get_tf_type_name(t) for t in inference_io_types), + _get_tf_type_name(inference_output_type))) output_dequant_ops.append(op) if len(subgraph.outputs) != len(output_dequant_ops): @@ -756,7 +772,8 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32): # Find a quantize operator quant_opcode_idx = -1 for idx, opcode in enumerate(model.operatorCodes): - if opcode.builtinCode == schema_fb.BuiltinOperator.QUANTIZE: + builtin_code = schema_util.get_builtin_code_from_operator_code(opcode) + if builtin_code == schema_fb.BuiltinOperator.QUANTIZE: quant_opcode_idx = idx break # Create a quantize operator, if none exist @@ -775,7 +792,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32): uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] tensors[op.outputs[0]].quantization = uint8_quantization tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8 - elif inference_output_type == dtypes.int8: + elif inference_output_type in _MAP_QUANT_TO_IO_TYPES: # Remove the outputs and the dequant operator remove_tensors_idxs = set() for op in output_dequant_ops: @@ -786,10 +803,8 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32): _remove_tensors_from_model(model, remove_tensors_idxs) else: raise ValueError( - "Unsupported `inference_output_type` value. Expected to be in {}, " - "instead got {}.".format(tuple(_get_tf_type_name(t) for t in - _TFLITE_MODEL_INPUT_OUTPUT_TYPES), - _get_tf_type_name(inference_output_type))) + "Unsupported `inference_output_type` value {}.".format( + _get_tf_type_name(inference_output_type))) def modify_model_io_type( @@ -801,11 +816,12 @@ def modify_model_io_type( model: A tflite model. inference_input_type: tf.DType representing modified input type. (default tf.float32. If model input is int8 quantized, it must be in - {tf.float32, tf.int8, tf.uint8}, else it must be tf.float32) + {tf.float32, tf.int8,tf.uint8}, else if model input is int16 quantized, + it must be in {tf.float32, tf.int16}, else it must be tf.float32) inference_output_type: tf.DType representing modified output type. (default tf.float32. If model output is int8 dequantized, it must be in - {tf.float32, tf.int8, tf.uint8}, else it must be tf.float32) - + {tf.float32, tf.int8,tf.uint8}, else if model output is int16 dequantized, + it must be in {tf.float32, tf.int16}, else it must be tf.float32) Returns: A tflite model with modified input/output type. @@ -831,4 +847,3 @@ def modify_model_io_type( _modify_model_output_type(model_object, inference_output_type) return _convert_model_from_object_to_bytearray(model_object) - diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py index 528caeab7d3..e98b50de0de 100644 --- a/tensorflow/lite/python/util_test.py +++ b/tensorflow/lite/python/util_test.py @@ -230,7 +230,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase): self.assertAllEqual([None, 3, 5], tensor.shape) -def _generate_integer_tflite_model(): +def _generate_integer_tflite_model(quantization_type=dtypes.int8): """Define an integer post-training quantized tflite model.""" # Load MNIST dataset n = 10 # Number of samples @@ -276,7 +276,13 @@ def _generate_integer_tflite_model(): np.float32) ] converter.representative_dataset = representative_dataset_gen - converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8} + if quantization_type == dtypes.int8: + converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8} + else: + converter.target_spec.supported_ops = { + tf.lite.OpsSet + .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 + } tflite_model = converter.convert() return tflite_model @@ -285,22 +291,24 @@ def _generate_integer_tflite_model(): def _test_param_modify_integer_model_io_type(): """Function to generate parameterized inputs for testing.""" params = [] - str_template = "_{}{}{}" + str_template = "_{}{}{}{}" map_model_type = { "PostTraining": True, # "DuringTraining": False, } - map_types = { - "": dtypes.float32, - "INT8": dtypes.int8, - "UINT8": dtypes.uint8, + map_quantize_type_to_io_types = { + tf.int8: {tf.float32, tf.int8, tf.uint8}, + tf.int16: {tf.float32, tf.int16} } for k1, v1 in map_model_type.items(): - for k2, v2 in map_types.items(): - istr = "_Input{}".format(k2) if k2 else "" - for k3, v3 in map_types.items(): - ostr = "_Output{}".format(k3) if k3 else "" if istr else "_NoUpdate" - params.append((str_template.format(k1, istr, ostr), v1, v2, v3)) + for qtype, v2 in map_quantize_type_to_io_types.items(): + qstr = "_IntegerQuantize{}".format(qtype.name.capitalize()) + for itype in v2: + istr = "_Input{}".format(itype.name.capitalize()) + for otype in v2: + ostr = "_Output{}".format(otype.name.capitalize()) + params.append((str_template.format(k1, qstr, istr, ostr), + v1, qtype, itype, otype)) return params @@ -311,10 +319,12 @@ class UtilModifyIntegerQuantizedModelIOTypeTest( @classmethod def setUpClass(cls): super(UtilModifyIntegerQuantizedModelIOTypeTest, cls).setUpClass() - cls.post_train_integer_model = _generate_integer_tflite_model() + cls.post_train_int8_model = _generate_integer_tflite_model() + cls.post_train_int16_model = _generate_integer_tflite_model( + quantization_type=dtypes.int16) @parameterized.named_parameters(_test_param_modify_integer_model_io_type()) - def test(self, is_post_train, in_tftype, out_tftype): + def test(self, is_post_train, quantization_type, in_tftype, out_tftype): """Modify the float input/output type of an integer quantized model.""" def _run_tflite_inference(model, in_tftype, out_tftype): @@ -353,7 +363,12 @@ class UtilModifyIntegerQuantizedModelIOTypeTest( return output_data - model = self.__class__.post_train_integer_model if is_post_train else None + if is_post_train and quantization_type == tf.int8: + model = self.__class__.post_train_int8_model + elif is_post_train and quantization_type == tf.int16: + model = self.__class__.post_train_int16_model + else: + model = None # Run model inference with float input output type output_data = _run_tflite_inference(model, tf.float32, tf.float32) # Run model inference with modified integer input output type diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD index 13a996cf56e..a3e0952d627 100644 --- a/tensorflow/lite/schema/BUILD +++ b/tensorflow/lite/schema/BUILD @@ -139,7 +139,7 @@ cc_library( deps = [ ":schema_fbs", "//tensorflow/lite/kernels/internal:compatibility", - "@flatbuffers", + "@flatbuffers//:runtime_cc", ], ) diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 71312a7b016..62045344755 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -351,7 +351,8 @@ enum BuiltinOperator : int32 { DENSIFY = 124, SEGMENT_SUM = 125, BATCH_MATMUL = 126, - PLACEHOLDER_FOR_GREATER_OP_CODES = 127 + PLACEHOLDER_FOR_GREATER_OP_CODES = 127, + CUMSUM = 128 } @@ -457,7 +458,8 @@ union BuiltinOptions { SelectV2Options, DensifyOptions, SegmentSumOptions, - BatchMatMulOptions + BatchMatMulOptions, + CumsumOptions, } enum Padding : byte { SAME, VALID } @@ -981,6 +983,11 @@ table BatchMatMulOptions { adj_y:bool; } +table CumsumOptions { + exclusive:bool; + reverse:bool; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index bbc000cc5dc..e7d91a93a99 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -349,6 +349,9 @@ struct SegmentSumOptionsT; struct BatchMatMulOptions; struct BatchMatMulOptionsT; +struct CumsumOptions; +struct CumsumOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -364,6 +367,12 @@ struct BufferT; struct Metadata; struct MetadataT; +struct TensorMap; +struct TensorMapT; + +struct SignatureDef; +struct SignatureDefT; + struct Model; struct ModelT; @@ -782,11 +791,12 @@ enum BuiltinOperator { BuiltinOperator_SEGMENT_SUM = 125, BuiltinOperator_BATCH_MATMUL = 126, BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127, + BuiltinOperator_CUMSUM = 128, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES + BuiltinOperator_MAX = BuiltinOperator_CUMSUM }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[128] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[129] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -915,13 +925,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[128] { BuiltinOperator_DENSIFY, BuiltinOperator_SEGMENT_SUM, BuiltinOperator_BATCH_MATMUL, - BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES + BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES, + BuiltinOperator_CUMSUM }; return values; } inline const char * const *EnumNamesBuiltinOperator() { - static const char * const names[129] = { + static const char * const names[130] = { "ADD", "AVERAGE_POOL_2D", "CONCATENATION", @@ -1050,13 +1061,14 @@ inline const char * const *EnumNamesBuiltinOperator() { "SEGMENT_SUM", "BATCH_MATMUL", "PLACEHOLDER_FOR_GREATER_OP_CODES", + "CUMSUM", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_CUMSUM)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOperator()[index]; } @@ -1164,11 +1176,12 @@ enum BuiltinOptions { BuiltinOptions_DensifyOptions = 99, BuiltinOptions_SegmentSumOptions = 100, BuiltinOptions_BatchMatMulOptions = 101, + BuiltinOptions_CumsumOptions = 102, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_BatchMatMulOptions + BuiltinOptions_MAX = BuiltinOptions_CumsumOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[103] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1271,13 +1284,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] { BuiltinOptions_SelectV2Options, BuiltinOptions_DensifyOptions, BuiltinOptions_SegmentSumOptions, - BuiltinOptions_BatchMatMulOptions + BuiltinOptions_BatchMatMulOptions, + BuiltinOptions_CumsumOptions }; return values; } inline const char * const *EnumNamesBuiltinOptions() { - static const char * const names[103] = { + static const char * const names[104] = { "NONE", "Conv2DOptions", "DepthwiseConv2DOptions", @@ -1380,13 +1394,14 @@ inline const char * const *EnumNamesBuiltinOptions() { "DensifyOptions", "SegmentSumOptions", "BatchMatMulOptions", + "CumsumOptions", nullptr }; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BatchMatMulOptions)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_CumsumOptions)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOptions()[index]; } @@ -1799,6 +1814,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CumsumOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2639,6 +2658,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_BatchMatMulOptions ? reinterpret_cast(value) : nullptr; } + tflite::CumsumOptionsT *AsCumsumOptions() { + return type == BuiltinOptions_CumsumOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::CumsumOptionsT *AsCumsumOptions() const { + return type == BuiltinOptions_CumsumOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -9337,6 +9364,72 @@ inline flatbuffers::Offset CreateBatchMatMulOptions( flatbuffers::Offset CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct CumsumOptionsT : public flatbuffers::NativeTable { + typedef CumsumOptions TableType; + bool exclusive; + bool reverse; + CumsumOptionsT() + : exclusive(false), + reverse(false) { + } +}; + +struct CumsumOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CumsumOptionsT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_EXCLUSIVE = 4, + VT_REVERSE = 6 + }; + bool exclusive() const { + return GetField(VT_EXCLUSIVE, 0) != 0; + } + bool reverse() const { + return GetField(VT_REVERSE, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_EXCLUSIVE) && + VerifyField(verifier, VT_REVERSE) && + verifier.EndTable(); + } + CumsumOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CumsumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CumsumOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_exclusive(bool exclusive) { + fbb_.AddElement(CumsumOptions::VT_EXCLUSIVE, static_cast(exclusive), 0); + } + void add_reverse(bool reverse) { + fbb_.AddElement(CumsumOptions::VT_REVERSE, static_cast(reverse), 0); + } + explicit CumsumOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CumsumOptionsBuilder &operator=(const CumsumOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateCumsumOptions( + flatbuffers::FlatBufferBuilder &_fbb, + bool exclusive = false, + bool reverse = false) { + CumsumOptionsBuilder builder_(_fbb); + builder_.add_reverse(reverse); + builder_.add_exclusive(exclusive); + return builder_.Finish(); +} + +flatbuffers::Offset CreateCumsumOptions(flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; int8_t deprecated_builtin_code; @@ -9790,6 +9883,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const { return builtin_options_type() == tflite::BuiltinOptions_BatchMatMulOptions ? static_cast(builtin_options()) : nullptr; } + const tflite::CumsumOptions *builtin_options_as_CumsumOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_CumsumOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -10230,6 +10326,10 @@ template<> inline const tflite::BatchMatMulOptions *Operator::builtin_options_as return builtin_options_as_BatchMatMulOptions(); } +template<> inline const tflite::CumsumOptions *Operator::builtin_options_as() const { + return builtin_options_as_CumsumOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -10593,6 +10693,193 @@ inline flatbuffers::Offset CreateMetadataDirect( flatbuffers::Offset CreateMetadata(flatbuffers::FlatBufferBuilder &_fbb, const MetadataT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct TensorMapT : public flatbuffers::NativeTable { + typedef TensorMap TableType; + std::string name; + uint32_t tensor_index; + TensorMapT() + : tensor_index(0) { + } +}; + +struct TensorMap FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TensorMapT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_TENSOR_INDEX = 6 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + uint32_t tensor_index() const { + return GetField(VT_TENSOR_INDEX, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyField(verifier, VT_TENSOR_INDEX) && + verifier.EndTable(); + } + TensorMapT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TensorMapT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorMapT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TensorMapBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(TensorMap::VT_NAME, name); + } + void add_tensor_index(uint32_t tensor_index) { + fbb_.AddElement(TensorMap::VT_TENSOR_INDEX, tensor_index, 0); + } + explicit TensorMapBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TensorMapBuilder &operator=(const TensorMapBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTensorMap( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + uint32_t tensor_index = 0) { + TensorMapBuilder builder_(_fbb); + builder_.add_tensor_index(tensor_index); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTensorMapDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + uint32_t tensor_index = 0) { + auto name__ = name ? _fbb.CreateString(name) : 0; + return tflite::CreateTensorMap( + _fbb, + name__, + tensor_index); +} + +flatbuffers::Offset CreateTensorMap(flatbuffers::FlatBufferBuilder &_fbb, const TensorMapT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SignatureDefT : public flatbuffers::NativeTable { + typedef SignatureDef TableType; + std::vector> inputs; + std::vector> outputs; + std::string method_name; + std::string key; + SignatureDefT() { + } +}; + +struct SignatureDef FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SignatureDefT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INPUTS = 4, + VT_OUTPUTS = 6, + VT_METHOD_NAME = 8, + VT_KEY = 10 + }; + const flatbuffers::Vector> *inputs() const { + return GetPointer> *>(VT_INPUTS); + } + const flatbuffers::Vector> *outputs() const { + return GetPointer> *>(VT_OUTPUTS); + } + const flatbuffers::String *method_name() const { + return GetPointer(VT_METHOD_NAME); + } + const flatbuffers::String *key() const { + return GetPointer(VT_KEY); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + verifier.VerifyVectorOfTables(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfTables(outputs()) && + VerifyOffset(verifier, VT_METHOD_NAME) && + verifier.VerifyString(method_name()) && + VerifyOffset(verifier, VT_KEY) && + verifier.VerifyString(key()) && + verifier.EndTable(); + } + SignatureDefT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SignatureDefT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SignatureDefT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SignatureDefBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_inputs(flatbuffers::Offset>> inputs) { + fbb_.AddOffset(SignatureDef::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset>> outputs) { + fbb_.AddOffset(SignatureDef::VT_OUTPUTS, outputs); + } + void add_method_name(flatbuffers::Offset method_name) { + fbb_.AddOffset(SignatureDef::VT_METHOD_NAME, method_name); + } + void add_key(flatbuffers::Offset key) { + fbb_.AddOffset(SignatureDef::VT_KEY, key); + } + explicit SignatureDefBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SignatureDefBuilder &operator=(const SignatureDefBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSignatureDef( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> inputs = 0, + flatbuffers::Offset>> outputs = 0, + flatbuffers::Offset method_name = 0, + flatbuffers::Offset key = 0) { + SignatureDefBuilder builder_(_fbb); + builder_.add_key(key); + builder_.add_method_name(method_name); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateSignatureDefDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *inputs = nullptr, + const std::vector> *outputs = nullptr, + const char *method_name = nullptr, + const char *key = nullptr) { + auto inputs__ = inputs ? _fbb.CreateVector>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector>(*outputs) : 0; + auto method_name__ = method_name ? _fbb.CreateString(method_name) : 0; + auto key__ = key ? _fbb.CreateString(key) : 0; + return tflite::CreateSignatureDef( + _fbb, + inputs__, + outputs__, + method_name__, + key__); +} + +flatbuffers::Offset CreateSignatureDef(flatbuffers::FlatBufferBuilder &_fbb, const SignatureDefT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct ModelT : public flatbuffers::NativeTable { typedef Model TableType; uint32_t version; @@ -10602,6 +10889,7 @@ struct ModelT : public flatbuffers::NativeTable { std::vector> buffers; std::vector metadata_buffer; std::vector> metadata; + std::vector> signature_defs; ModelT() : version(0) { } @@ -10616,7 +10904,8 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_DESCRIPTION = 10, VT_BUFFERS = 12, VT_METADATA_BUFFER = 14, - VT_METADATA = 16 + VT_METADATA = 16, + VT_SIGNATURE_DEFS = 18 }; uint32_t version() const { return GetField(VT_VERSION, 0); @@ -10639,6 +10928,9 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector> *metadata() const { return GetPointer> *>(VT_METADATA); } + const flatbuffers::Vector> *signature_defs() const { + return GetPointer> *>(VT_SIGNATURE_DEFS); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_VERSION) && @@ -10658,6 +10950,9 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_METADATA) && verifier.VerifyVector(metadata()) && verifier.VerifyVectorOfTables(metadata()) && + VerifyOffset(verifier, VT_SIGNATURE_DEFS) && + verifier.VerifyVector(signature_defs()) && + verifier.VerifyVectorOfTables(signature_defs()) && verifier.EndTable(); } ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -10689,6 +10984,9 @@ struct ModelBuilder { void add_metadata(flatbuffers::Offset>> metadata) { fbb_.AddOffset(Model::VT_METADATA, metadata); } + void add_signature_defs(flatbuffers::Offset>> signature_defs) { + fbb_.AddOffset(Model::VT_SIGNATURE_DEFS, signature_defs); + } explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -10709,8 +11007,10 @@ inline flatbuffers::Offset CreateModel( flatbuffers::Offset description = 0, flatbuffers::Offset>> buffers = 0, flatbuffers::Offset> metadata_buffer = 0, - flatbuffers::Offset>> metadata = 0) { + flatbuffers::Offset>> metadata = 0, + flatbuffers::Offset>> signature_defs = 0) { ModelBuilder builder_(_fbb); + builder_.add_signature_defs(signature_defs); builder_.add_metadata(metadata); builder_.add_metadata_buffer(metadata_buffer); builder_.add_buffers(buffers); @@ -10729,13 +11029,15 @@ inline flatbuffers::Offset CreateModelDirect( const char *description = nullptr, const std::vector> *buffers = nullptr, const std::vector *metadata_buffer = nullptr, - const std::vector> *metadata = nullptr) { + const std::vector> *metadata = nullptr, + const std::vector> *signature_defs = nullptr) { auto operator_codes__ = operator_codes ? _fbb.CreateVector>(*operator_codes) : 0; auto subgraphs__ = subgraphs ? _fbb.CreateVector>(*subgraphs) : 0; auto description__ = description ? _fbb.CreateString(description) : 0; auto buffers__ = buffers ? _fbb.CreateVector>(*buffers) : 0; auto metadata_buffer__ = metadata_buffer ? _fbb.CreateVector(*metadata_buffer) : 0; auto metadata__ = metadata ? _fbb.CreateVector>(*metadata) : 0; + auto signature_defs__ = signature_defs ? _fbb.CreateVector>(*signature_defs) : 0; return tflite::CreateModel( _fbb, version, @@ -10744,7 +11046,8 @@ inline flatbuffers::Offset CreateModelDirect( description__, buffers__, metadata_buffer__, - metadata__); + metadata__, + signature_defs__); } flatbuffers::Offset CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -10758,7 +11061,7 @@ inline CustomQuantizationT *CustomQuantization::UnPack(const flatbuffers::resolv inline void CustomQuantization::UnPackTo(CustomQuantizationT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = custom(); if (_e) { _o->custom.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom[_i] = _e->Get(_i); } } } + { auto _e = custom(); if (_e) { _o->custom.resize(_e->size()); std::copy(_e->begin(), _e->end(), _o->custom.begin()); } } } inline flatbuffers::Offset CustomQuantization::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10882,7 +11185,7 @@ inline Uint8VectorT *Uint8Vector::UnPack(const flatbuffers::resolver_function_t inline void Uint8Vector::UnPackTo(Uint8VectorT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } } + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); std::copy(_e->begin(), _e->end(), _o->values.begin()); } } } inline flatbuffers::Offset Uint8Vector::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Uint8VectorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -13703,6 +14006,35 @@ inline flatbuffers::Offset CreateBatchMatMulOptions(flatbuff _adj_y); } +inline CumsumOptionsT *CumsumOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new CumsumOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void CumsumOptions::UnPackTo(CumsumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = exclusive(); _o->exclusive = _e; } + { auto _e = reverse(); _o->reverse = _e; } +} + +inline flatbuffers::Offset CumsumOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateCumsumOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateCumsumOptions(flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CumsumOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _exclusive = _o->exclusive; + auto _reverse = _o->reverse; + return tflite::CreateCumsumOptions( + _fbb, + _exclusive, + _reverse); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -13752,7 +14084,7 @@ inline void Operator::UnPackTo(OperatorT *_o, const flatbuffers::resolver_functi { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } } { auto _e = builtin_options_type(); _o->builtin_options.type = _e; } { auto _e = builtin_options(); if (_e) _o->builtin_options.value = tflite::BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); } - { auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom_options[_i] = _e->Get(_i); } } } + { auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); std::copy(_e->begin(), _e->end(), _o->custom_options.begin()); } } { auto _e = custom_options_format(); _o->custom_options_format = _e; } { auto _e = mutating_variable_inputs(); if (_e) { _o->mutating_variable_inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->mutating_variable_inputs[_i] = _e->Get(_i) != 0; } } } { auto _e = intermediates(); if (_e) { _o->intermediates.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->intermediates[_i] = _e->Get(_i); } } } @@ -13835,7 +14167,7 @@ inline BufferT *Buffer::UnPack(const flatbuffers::resolver_function_t *_resolver inline void Buffer::UnPackTo(BufferT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = data(); if (_e) { _o->data.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->data[_i] = _e->Get(_i); } } } + { auto _e = data(); if (_e) { _o->data.resize(_e->size()); std::copy(_e->begin(), _e->end(), _o->data.begin()); } } } inline flatbuffers::Offset Buffer::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -13882,6 +14214,70 @@ inline flatbuffers::Offset CreateMetadata(flatbuffers::FlatBufferBuild _buffer); } +inline TensorMapT *TensorMap::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TensorMapT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void TensorMap::UnPackTo(TensorMapT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = name(); if (_e) _o->name = _e->str(); } + { auto _e = tensor_index(); _o->tensor_index = _e; } +} + +inline flatbuffers::Offset TensorMap::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorMapT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateTensorMap(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateTensorMap(flatbuffers::FlatBufferBuilder &_fbb, const TensorMapT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TensorMapT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + auto _tensor_index = _o->tensor_index; + return tflite::CreateTensorMap( + _fbb, + _name, + _tensor_index); +} + +inline SignatureDefT *SignatureDef::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SignatureDefT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SignatureDef::UnPackTo(SignatureDefT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } + { auto _e = method_name(); if (_e) _o->method_name = _e->str(); } + { auto _e = key(); if (_e) _o->key = _e->str(); } +} + +inline flatbuffers::Offset SignatureDef::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SignatureDefT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSignatureDef(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSignatureDef(flatbuffers::FlatBufferBuilder &_fbb, const SignatureDefT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SignatureDefT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _inputs = _o->inputs.size() ? _fbb.CreateVector> (_o->inputs.size(), [](size_t i, _VectorArgs *__va) { return CreateTensorMap(*__va->__fbb, __va->__o->inputs[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _outputs = _o->outputs.size() ? _fbb.CreateVector> (_o->outputs.size(), [](size_t i, _VectorArgs *__va) { return CreateTensorMap(*__va->__fbb, __va->__o->outputs[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _method_name = _o->method_name.empty() ? 0 : _fbb.CreateString(_o->method_name); + auto _key = _o->key.empty() ? 0 : _fbb.CreateString(_o->key); + return tflite::CreateSignatureDef( + _fbb, + _inputs, + _outputs, + _method_name, + _key); +} + inline ModelT *Model::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ModelT(); UnPackTo(_o, _resolver); @@ -13898,6 +14294,7 @@ inline void Model::UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t * { auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } { auto _e = metadata_buffer(); if (_e) { _o->metadata_buffer.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->metadata_buffer[_i] = _e->Get(_i); } } } { auto _e = metadata(); if (_e) { _o->metadata.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } + { auto _e = signature_defs(); if (_e) { _o->signature_defs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->signature_defs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } } inline flatbuffers::Offset Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -13915,6 +14312,7 @@ inline flatbuffers::Offset CreateModel(flatbuffers::FlatBufferBuilder &_f auto _buffers = _o->buffers.size() ? _fbb.CreateVector> (_o->buffers.size(), [](size_t i, _VectorArgs *__va) { return CreateBuffer(*__va->__fbb, __va->__o->buffers[i].get(), __va->__rehasher); }, &_va ) : 0; auto _metadata_buffer = _o->metadata_buffer.size() ? _fbb.CreateVector(_o->metadata_buffer) : 0; auto _metadata = _o->metadata.size() ? _fbb.CreateVector> (_o->metadata.size(), [](size_t i, _VectorArgs *__va) { return CreateMetadata(*__va->__fbb, __va->__o->metadata[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _signature_defs = _o->signature_defs.size() ? _fbb.CreateVector> (_o->signature_defs.size(), [](size_t i, _VectorArgs *__va) { return CreateSignatureDef(*__va->__fbb, __va->__o->signature_defs[i].get(), __va->__rehasher); }, &_va ) : 0; return tflite::CreateModel( _fbb, _version, @@ -13923,7 +14321,8 @@ inline flatbuffers::Offset CreateModel(flatbuffers::FlatBufferBuilder &_f _description, _buffers, _metadata_buffer, - _metadata); + _metadata, + _signature_defs); } inline bool VerifyQuantizationDetails(flatbuffers::Verifier &verifier, const void *obj, QuantizationDetails type) { @@ -14515,6 +14914,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_CumsumOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } @@ -14937,6 +15340,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_CumsumOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -15347,6 +15754,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateBatchMatMulOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_CumsumOptions: { + auto ptr = reinterpret_cast(value); + return CreateCumsumOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -15757,6 +16168,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new tflite::BatchMatMulOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_CumsumOptions: { + value = new tflite::CumsumOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -16269,6 +16684,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_CumsumOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/lite/stderr_reporter_test.cc b/tensorflow/lite/stderr_reporter_test.cc new file mode 100644 index 00000000000..264b7f7b313 --- /dev/null +++ b/tensorflow/lite/stderr_reporter_test.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/stderr_reporter.h" + +#include + +#include + +namespace tflite { + +namespace { + +void CheckWritesToStderr(ErrorReporter *error_reporter) { +#ifndef TF_LITE_STRIP_ERROR_STRINGS + testing::internal::CaptureStderr(); +#endif + + // Run the code under test. + TF_LITE_REPORT_ERROR(error_reporter, "Test: %d", 42); + +#ifndef TF_LITE_STRIP_ERROR_STRINGS + EXPECT_EQ("ERROR: Test: 42\n", testing::internal::GetCapturedStderr()); +#endif +} + +TEST(StderrReporterTest, DefaultErrorReporter_WritesToStderr) { + CheckWritesToStderr(DefaultErrorReporter()); +} + +TEST(StderrReporterTest, StderrReporter_WritesToStderr) { + StderrReporter stderr_reporter; + CheckWritesToStderr(&stderr_reporter); +} + +} // namespace + +} // namespace tflite diff --git a/tensorflow/lite/testdata/unidirectional_sequence_lstm.bin b/tensorflow/lite/testdata/unidirectional_sequence_lstm.bin new file mode 100644 index 00000000000..42c96d14faa Binary files /dev/null and b/tensorflow/lite/testdata/unidirectional_sequence_lstm.bin differ diff --git a/tensorflow/lite/testing/toco_convert.py b/tensorflow/lite/testing/toco_convert.py index 0803f4de600..48c19c49686 100644 --- a/tensorflow/lite/testing/toco_convert.py +++ b/tensorflow/lite/testing/toco_convert.py @@ -114,6 +114,7 @@ def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs): converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( graphdef_file.name, input_arrays, output_tensors, input_shapes) + converter.experimental_new_converter = options.use_experimental_converter converter.optimizations = [tf.lite.Optimize.DEFAULT] if fully_quantize: diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index b5a55cfa090..b7bbcf49563 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -662,10 +662,10 @@ TEST_F(VersionedOpExportTest, Export) { // different versions. EXPECT_EQ(2, operator_codes->size()); EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, - (*operator_codes)[0]->builtin_code()); + GetBuiltinCode((*operator_codes)[0])); EXPECT_EQ(1, (*operator_codes)[0]->version()); EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, - (*operator_codes)[1]->builtin_code()); + GetBuiltinCode((*operator_codes)[1])); EXPECT_EQ(2, (*operator_codes)[1]->version()); // Verify that the 2 operators points to the correct indices of the operation diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index 22fa1ff1cea..078f139f19b 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -319,6 +319,7 @@ cc_library( hdrs = ["list_flex_ops.h"], deps = [ "//tensorflow/lite:framework", + "//tensorflow/lite/schema:schema_utils", "@jsoncpp_git//:jsoncpp", ], ) diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h index 8917c254825..e04e1a12cd4 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h +++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h @@ -46,8 +46,17 @@ extern "C" { typedef enum TfLiteStatus { kTfLiteOk = 0, + + // Generally referring to an error in the runtime (i.e. interpreter) kTfLiteError = 1, + + // Generally referring to an error from a TfLiteDelegate itself. kTfLiteDelegateError = 2, + + // Generally referring to an error in applying a delegate due to + // incompatibility between runtime and delegate, e.g., this error is returned + // when trying to apply a TfLite delegate onto a model graph that's already + // immutable. kTfLiteApplicationError = 3 } TfLiteStatus; diff --git a/tensorflow/lite/tools/cmake/README.md b/tensorflow/lite/tools/cmake/README.md index 23ddd2d8093..159ed6d3343 100644 --- a/tensorflow/lite/tools/cmake/README.md +++ b/tensorflow/lite/tools/cmake/README.md @@ -4,13 +4,13 @@ This page describes how to build the TensorFlow Lite static library with CMake tool. The following instructions have been tested on Ubuntu 16.04.3 64-bit PC (AMD64) -and TensorFlow devel docker image +, TensorFlow devel docker image and Windows 10. [tensorflow/tensorflow:devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/). **Note:** This is an experimental that is subject to change. -**Note:** The following are not currently supported: Android, iOS, Tests and -Host Tools (i.e benchmark / analysis tools etc.) +**Note:** The following are not currently supported: iOS, Tests and +Host Tools (i.e analysis tools etc.) #### Step 1. Install CMake tool @@ -73,6 +73,12 @@ current directory. In the tflite_build directory, +```sh +cmake --build . -j -t benchmark_model +``` + +Or + ```sh make benchmark_model -j ``` diff --git a/tensorflow/lite/tools/list_flex_ops_no_kernel.cc b/tensorflow/lite/tools/list_flex_ops_no_kernel.cc index 68d40be1c9c..e90e3d75f22 100644 --- a/tensorflow/lite/tools/list_flex_ops_no_kernel.cc +++ b/tensorflow/lite/tools/list_flex_ops_no_kernel.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "json/json.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/tools/list_flex_ops.h" namespace tflite { @@ -40,7 +41,7 @@ void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops) { for (int i = 0; i < operators->size(); ++i) { const tflite::Operator* op = operators->Get(i); const tflite::OperatorCode* opcode = opcodes->Get(op->opcode_index()); - if (opcode->builtin_code() != tflite::BuiltinOperator_CUSTOM || + if (tflite::GetBuiltinCode(opcode) != tflite::BuiltinOperator_CUSTOM || !tflite::IsFlexOp(opcode->custom_code()->c_str())) { continue; } diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 9beb7a239e8..7157a7c1002 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -312,6 +312,8 @@ tf_cc_test( "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", "//tensorflow/lite/tools/optimize:testdata/transpose.bin", + "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin", + "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin", "//tensorflow/lite/tools/optimize:testdata/unpack.bin", ], tags = [ diff --git a/tensorflow/lite/tools/optimize/calibration/BUILD b/tensorflow/lite/tools/optimize/calibration/BUILD index 53bd1bb4faf..2f315ca509a 100644 --- a/tensorflow/lite/tools/optimize/calibration/BUILD +++ b/tensorflow/lite/tools/optimize/calibration/BUILD @@ -65,6 +65,7 @@ tf_cc_test( data = [ "//tensorflow/lite:testdata/lstm.bin", "//tensorflow/lite:testdata/multi_add.bin", + "//tensorflow/lite:testdata/unidirectional_sequence_lstm.bin", ], tags = [ "tflite_not_portable_android", diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc index 6249649f4b7..bdf27d9a980 100644 --- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc +++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc @@ -125,15 +125,14 @@ void CalculateLstmOutputCalibration( const float* output_gate, TfLiteFusedActivation activation, const float* projection_weights, const float* projection_bias, const float proj_clip, float* output_state, float* scratch, Logger* logger, - const std::vector& intermediate_tensor_indexes, - ErrorReporter* error_reporter) { + int intermediate_tensor_index, ErrorReporter* error_reporter) { tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell, activation, scratch); tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell, scratch); - logger->LogTensorValue(intermediate_tensor_indexes[4], scratch, - n_cell * n_batch, error_reporter); + logger->LogTensorValue(intermediate_tensor_index, scratch, n_cell * n_batch, + error_reporter); const bool use_projection = (projection_weights != nullptr); const bool use_projection_bias = (projection_bias != nullptr); @@ -252,7 +251,7 @@ inline void LstmStepCalibration( n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch, params->activation, projection_weights_ptr, projection_bias_ptr, params->proj_clip, output_state_ptr, scratch2, logger, - intermediate_tensor_indexes, error_reporter); + intermediate_tensor_indexes[4], error_reporter); // Copy output state to the output. Note that the output's rows may not be // contiguous (output_batch_leading_dim != n_output). for (int b = 0; b < n_batch; b++) { @@ -462,10 +461,9 @@ struct OpData { // Resize the output, state tensors based on the sizes of the input tensors. // Allocate a temporary scratch tensor. Also check that the sizes of the input // tensors match each other. -TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger, +TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, + LSTMType lstm_type, Logger* logger, ErrorReporter* error_reporter) { - const auto* params = static_cast(node->builtin_data); - const TfLiteTensor* input; TF_LITE_ENSURE_OK( context, GetInputSafe(context, node, @@ -573,10 +571,37 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger, ops::builtin::lstm::full::kOutputTensor, &output)); std::vector intermediate_tensor_indexes(node->intermediates->size); + // LSTM expect 5 intermediate tensors. + TF_LITE_ENSURE_EQ(context, node->intermediates->size, 5); for (int i = 0; i < node->intermediates->size; ++i) { intermediate_tensor_indexes[i] = node->intermediates->data[i]; } + TfLiteLSTMParams lstm_params; + bool time_major = true; + switch (lstm_type) { + case LSTMType::kLSTM: { + lstm_params = *(static_cast(node->builtin_data)); + time_major = true; + break; + } + case LSTMType::kUnidirectionalSequenceLSTM: { + const auto* params = static_cast( + node->builtin_data); + // Copy out the LSTM specific params so they can be passed in the + // function. + lstm_params.activation = params->activation; + lstm_params.cell_clip = params->cell_clip; + lstm_params.proj_clip = params->proj_clip; + lstm_params.asymmetric_quantize_inputs = + params->asymmetric_quantize_inputs; + time_major = params->time_major; + break; + } + default: + return kTfLiteError; + } + switch (input_to_output_weights->type) { case kTfLiteFloat32: { return EvalCalibration( @@ -593,9 +618,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger, /*aux_input_to_cell_weights=*/nullptr, /*aux_input_to_output_weights=*/nullptr, input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias, - projection_weights, projection_bias, params, + projection_weights, projection_bias, &lstm_params, /*forward_sequence=*/true, - /*time_major=*/true, + /*time_major=*/time_major, /*output_offset=*/0, scratch_buffer, output_state, cell_state, output, logger, intermediate_tensor_indexes, error_reporter); } @@ -612,7 +637,14 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger, TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node, Logger* logger, ErrorReporter* error_reporter) { - return lstm_eval(context, node, logger, error_reporter); + return lstm_eval(context, node, LSTMType::kLSTM, logger, error_reporter); +} + +TfLiteStatus unidirectional_sequence_lstm_logging_kernel( + TfLiteContext* context, TfLiteNode* node, Logger* logger, + ErrorReporter* error_reporter) { + return lstm_eval(context, node, LSTMType::kUnidirectionalSequenceLSTM, logger, + error_reporter); } } // namespace builtin diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h index f3306bc0564..0a9e7095507 100644 --- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h +++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h @@ -23,9 +23,18 @@ namespace optimize { namespace calibration { namespace builtin { +enum class LSTMType { + kLSTM, + kUnidirectionalSequenceLSTM, +}; + TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node, Logger* logger, ErrorReporter* error_reporter); +TfLiteStatus unidirectional_sequence_lstm_logging_kernel( + TfLiteContext* context, TfLiteNode* node, Logger* logger, + ErrorReporter* error_reporter); + } // namespace builtin } // namespace calibration } // namespace optimize diff --git a/tensorflow/lite/tools/optimize/calibration/calibrator.cc b/tensorflow/lite/tools/optimize/calibration/calibrator.cc index be8fad8a221..6cddbc53009 100644 --- a/tensorflow/lite/tools/optimize/calibration/calibrator.cc +++ b/tensorflow/lite/tools/optimize/calibration/calibrator.cc @@ -174,13 +174,17 @@ GlobalCalibratorRegistry* GetCalibratorRegistry() { // TODO(jianlijianli): extend this to support multiple recipe for the same // model. logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context, - TfLiteNode* node) { - const int lstm_number_input = 24; - if (node->inputs->size == lstm_number_input) { - // LSTM Op. - return tflite::optimize::calibration::builtin::lstm_logging_kernel; + TfLiteNode* node, + int builtin_op_code) { + switch (builtin_op_code) { + case BuiltinOperator_LSTM: + return tflite::optimize::calibration::builtin::lstm_logging_kernel; + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: + return tflite::optimize::calibration::builtin:: + unidirectional_sequence_lstm_logging_kernel; + default: + return nullptr; } - return nullptr; } // A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs, @@ -203,7 +207,9 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(logger->LogTensorValue( i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter)); } - auto kernel_invoke_intermediate = GetLoggingEvalFunc(context, node); + auto builtin_op_code = calibrator->GetOpInfo(node).builtin_op_code; + auto kernel_invoke_intermediate = + GetLoggingEvalFunc(context, node, builtin_op_code); TfLiteStatus status; if (kernel_invoke_intermediate == nullptr) { status = kernel_invoke(context, node); diff --git a/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc b/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc index f0cd27ef620..c2e205f2a6e 100644 --- a/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc +++ b/tensorflow/lite/tools/optimize/calibration/calibrator_test.cc @@ -283,7 +283,7 @@ TEST(CalibratorTest, LSTM) { auto status = BuildLoggingInterpreter(*flatbuffer_model, ops::builtin::BuiltinOpResolver{}, &interpreter, &reader); - EXPECT_EQ(kTfLiteOk, status); + EXPECT_EQ(status, kTfLiteOk); auto readonly_model = flatbuffer_model->GetModel(); tflite::ModelT model; @@ -294,24 +294,17 @@ TEST(CalibratorTest, LSTM) { status = interpreter->AllocateTensors(); EXPECT_EQ(kTfLiteOk, status); - const std::vector lstm_input = { - 0.3, 0.2, 0.9, 0.8, 0.1, // - 0.1, 0.5, 0.2, 0.4, 0.2, // - 0.6, 0.9, 0.2, 0.5, 0.7, // - }; + const std::vector lstm_input = {0.3, 0.2}; int input_tensor_idx = interpreter->inputs()[0]; TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx); for (size_t j = 0; j < lstm_input.size(); j++) { tensor->data.f[j] = lstm_input[j]; } - // Invoke with update == true. - status = interpreter->Invoke(); - ASSERT_EQ(kTfLiteOk, status); + ASSERT_EQ(interpreter->Invoke(), kTfLiteOk); absl::flat_hash_map stats; - status = reader->GetTensorStatsAsMap(&stats); - EXPECT_EQ(kTfLiteOk, status); + EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk); // Check the results. const float eps = 1e-6f; @@ -344,6 +337,66 @@ TEST(CalibratorTest, LSTM) { } } +TEST(CalibratorTest, UnidirectionalSequenceLSTM) { + auto flatbuffer_model = ReadModel("unidirectional_sequence_lstm.bin"); + ASSERT_TRUE(flatbuffer_model); + std::unique_ptr interpreter; + std::unique_ptr reader; + auto status = BuildLoggingInterpreter(*flatbuffer_model, + ops::builtin::BuiltinOpResolver{}, + &interpreter, &reader); + EXPECT_EQ(kTfLiteOk, status); + + auto readonly_model = flatbuffer_model->GetModel(); + tflite::ModelT model; + readonly_model->UnPackTo(&model); + + ASSERT_TRUE(interpreter); + ASSERT_TRUE(reader); + EXPECT_EQ(interpreter->AllocateTensors(), kTfLiteOk); + const std::vector lstm_input = {0.3, 0.2, 0.9, 0.8}; + int input_tensor_idx = interpreter->inputs()[0]; + TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx); + for (size_t j = 0; j < lstm_input.size(); j++) { + tensor->data.f[j] = lstm_input[j]; + } + + ASSERT_EQ(interpreter->Invoke(), kTfLiteOk); + + absl::flat_hash_map stats; + EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk); + + // Check the results. + const float eps = 1e-6f; + const std::unordered_map + expected_calibration_result = { + // Input. + {0, {0.200000, 0.900000}}, + // State. + {18, {0.000000, 0.520999}}, + // State. + {19, {0.000000, 0.711364}}, + // Output. + {24, {0.247992, 0.520999}}, + // Intemediate_0. + {25, {0.080045, 0.824241}}, + // Intemediate_1. + {26, {0.080045, 0.824241}}, + // Intemediate_2. + {27, {0.080045, 0.824241}}, + // Intemediate_3. + {28, {0.080045, 0.824241}}, + // Intemediate_4. + {29, {0.000000, 0.413618}}, + }; + EXPECT_EQ(expected_calibration_result.size(), stats.size()); + for (const auto& e : stats) { + auto expected_result = expected_calibration_result.at(e.first); + EXPECT_NEAR(e.second.min, expected_result.min, eps); + EXPECT_NEAR(e.second.max, expected_result.max, eps); + } +} + } // namespace } // namespace calibration } // namespace optimize diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index 6375016d527..6ec320c4144 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -44,7 +44,8 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index, model->subgraphs.at(subgraph_index)->operators[op_index].get(); op_variant.op_code = GetBuiltinCode(model->operator_codes[op->opcode_index].get()); - if (op_variant.op_code == BuiltinOperator_LSTM) { + if (op_variant.op_code == BuiltinOperator_LSTM || + op_variant.op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) { if (op->inputs.size() == 5) { // The 5 input ("basic") LSTM is not supported in this tooling (yet). op_variant.is_quantizable = false; @@ -230,7 +231,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.version = 2; break; } - case BuiltinOperator_LSTM: { + case BuiltinOperator_LSTM: + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { if (!op_variant.is_quantizable) { // Early exist for 5 input LSTM. // It is not supported in this tooling yet. diff --git a/tensorflow/lite/tools/optimize/quantization_wrapper_utils.cc b/tensorflow/lite/tools/optimize/quantization_wrapper_utils.cc index 753cf99375a..45d27663444 100644 --- a/tensorflow/lite/tools/optimize/quantization_wrapper_utils.cc +++ b/tensorflow/lite/tools/optimize/quantization_wrapper_utils.cc @@ -68,6 +68,10 @@ TfLiteStatus LoadModel(const string& path, ModelT* model) { TfLiteStatus AddIntermediateTensorsToFusedOp( flatbuffers::FlatBufferBuilder* builder, ModelT* model) { + // Return early when the model has no operator. + if (model->subgraphs.size() == 1 && model->subgraphs[0]->operators.empty()) { + return kTfLiteOk; + } // Return early if the model already has intermediate tensors. if (IntermediateTensorExists(model)) { return kTfLiteOk; diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc index e7f1c7a8bdf..5db624258f6 100644 --- a/tensorflow/lite/tools/optimize/quantize_model.cc +++ b/tensorflow/lite/tools/optimize/quantize_model.cc @@ -825,16 +825,20 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model, if (input.second.number_of_bits == 8 && input.second.symmetric == false) { TensorT* tensor = subgraph->tensors[index_global].get(); + if (tensor->quantization == nullptr) { + continue; + } if (utils::HasMinMax(tensor)) { utils::QuantizeActivation(tensor, activations_type, error_reporter); } else { - TF_LITE_REPORT_ERROR( - error_reporter, - "Unable to find min/max value for output %d in %s in " - "subgraph %d, node: %d", - tensor, EnumNameBuiltinOperator(op_code), subgraph_idx, - op_idx); + TF_LITE_REPORT_ERROR(error_reporter, + "Unable to find min/max value for " + "intermediate tensor %d in %s in " + "subgraph %d, node: %d", + index_local, + EnumNameBuiltinOperator(op_code), + subgraph_idx, op_idx); return kTfLiteError; } } else if (input.second.number_of_bits == 16 && diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index 32a23033019..9afd163efd2 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -81,6 +81,44 @@ class QuantizeModelTest : public testing::Test { internal::FailOnErrorReporter error_reporter_; }; +void ExpectSameModels(const ModelT& model, const ModelT& expected_model) { + ASSERT_EQ(model.subgraphs.size(), expected_model.subgraphs.size()); + for (size_t subgraph_idx = 0; subgraph_idx < model.subgraphs.size(); + subgraph_idx++) { + const auto graph = model.subgraphs[subgraph_idx].get(); + const auto expected_graph = expected_model.subgraphs[subgraph_idx].get(); + ASSERT_EQ(graph->tensors.size(), expected_graph->tensors.size()); + for (size_t i = 0; i < graph->tensors.size(); i++) { + const auto tensor = graph->tensors[i].get(); + const auto expected_tensor = expected_graph->tensors[i].get(); + EXPECT_EQ(tensor->buffer, expected_tensor->buffer); + EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable); + EXPECT_EQ(tensor->shape, expected_tensor->shape); + EXPECT_EQ(tensor->name, expected_tensor->name); + EXPECT_EQ(tensor->type, expected_tensor->type); + const auto quantization_params = tensor->quantization.get(); + const auto expected_quantization_params = + expected_tensor->quantization.get(); + if (quantization_params != nullptr || + expected_quantization_params != nullptr) { + EXPECT_NE(quantization_params, nullptr); + EXPECT_NE(expected_quantization_params, nullptr); + EXPECT_EQ(quantization_params->scale, + expected_quantization_params->scale); + EXPECT_EQ(quantization_params->zero_point, + expected_quantization_params->zero_point); + } + } + } + ASSERT_EQ(model.buffers.size(), expected_model.buffers.size()); + for (size_t buffer_idx = 0; buffer_idx < model.buffers.size(); ++buffer_idx) { + const auto buffer = model.buffers[buffer_idx].get()->data; + const auto expected_buffer = expected_model.buffers[buffer_idx].get()->data; + EXPECT_EQ(buffer, expected_buffer); + } + // TODO(jianlijianli): Compare operators as well. +} + class QuantizeConvModelTest : public QuantizeModelTest, public testing::WithParamInterface { protected: @@ -1121,42 +1159,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); - // Comparison. - ASSERT_EQ(model_.subgraphs.size(), expected_model.subgraphs.size()); - for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size(); - subgraph_idx++) { - const auto graph = model_.subgraphs[subgraph_idx].get(); - const auto expected_graph = expected_model.subgraphs[subgraph_idx].get(); - ASSERT_EQ(graph->tensors.size(), expected_graph->tensors.size()); - for (size_t i = 0; i < graph->tensors.size(); i++) { - const auto tensor = graph->tensors[i].get(); - const auto expected_tensor = expected_graph->tensors[i].get(); - EXPECT_EQ(tensor->buffer, expected_tensor->buffer); - EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable); - EXPECT_EQ(tensor->shape, expected_tensor->shape); - EXPECT_EQ(tensor->name, expected_tensor->name); - EXPECT_EQ(tensor->type, expected_tensor->type); - const auto quantization_params = tensor->quantization.get(); - const auto expected_quantization_params = - expected_tensor->quantization.get(); - if (quantization_params != nullptr || - expected_quantization_params != nullptr) { - EXPECT_NE(quantization_params, nullptr); - EXPECT_NE(expected_quantization_params, nullptr); - EXPECT_EQ(quantization_params->scale, - expected_quantization_params->scale); - EXPECT_EQ(quantization_params->zero_point, - expected_quantization_params->zero_point); - } - } - } - ASSERT_EQ(model_.buffers.size(), expected_model.buffers.size()); - for (size_t buffer_idx = 0; buffer_idx < model_.buffers.size(); - ++buffer_idx) { - const auto buffer = model_.buffers[buffer_idx].get()->data; - const auto expected_buffer = expected_model.buffers[buffer_idx].get()->data; - EXPECT_EQ(buffer, expected_buffer); - } + ExpectSameModels(model_, expected_model); } class QuantizeLSTM2Test : public QuantizeModelTest { @@ -1181,42 +1184,34 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); - // Comparison. - ASSERT_EQ(model_.subgraphs.size(), expected_model.subgraphs.size()); - for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size(); - subgraph_idx++) { - const auto graph = model_.subgraphs[subgraph_idx].get(); - const auto expected_graph = expected_model.subgraphs[subgraph_idx].get(); - ASSERT_EQ(graph->tensors.size(), expected_graph->tensors.size()); - for (size_t i = 0; i < graph->tensors.size(); i++) { - const auto tensor = graph->tensors[i].get(); - const auto expected_tensor = expected_graph->tensors[i].get(); - EXPECT_EQ(tensor->buffer, expected_tensor->buffer); - EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable); - EXPECT_EQ(tensor->shape, expected_tensor->shape); - EXPECT_EQ(tensor->name, expected_tensor->name); - EXPECT_EQ(tensor->type, expected_tensor->type); - const auto quantization_params = tensor->quantization.get(); - const auto expected_quantization_params = - expected_tensor->quantization.get(); - if (quantization_params != nullptr || - expected_quantization_params != nullptr) { - EXPECT_NE(quantization_params, nullptr); - EXPECT_NE(expected_quantization_params, nullptr); - EXPECT_EQ(quantization_params->scale, - expected_quantization_params->scale); - EXPECT_EQ(quantization_params->zero_point, - expected_quantization_params->zero_point); - } - } - } - ASSERT_EQ(model_.buffers.size(), expected_model.buffers.size()); - for (size_t buffer_idx = 0; buffer_idx < model_.buffers.size(); - ++buffer_idx) { - const auto buffer = model_.buffers[buffer_idx].get()->data; - const auto expected_buffer = expected_model.buffers[buffer_idx].get()->data; - EXPECT_EQ(buffer, expected_buffer); + ExpectSameModels(model_, expected_model); +} + +class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { + protected: + QuantizeUnidirectionalSequenceLSTMTest() { + input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated); + readonly_model_ = input_model_->GetModel(); + readonly_model_->UnPackTo(&model_); } +}; + +TEST_F(QuantizeUnidirectionalSequenceLSTMTest, + VerifyUnidirectionalSequenceLSTM) { + // Quantize model. + auto status = QuantizeModelAllOperators( + &builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, false, + TensorType_INT8, &error_reporter_); + ASSERT_EQ(kTfLiteOk, status); + + // Read expected model. + auto expected_fb_model = + ReadModel(internal::kUnidirectionalSequenceLstmQuantized); + auto expected_read_only_model = expected_fb_model->GetModel(); + ModelT expected_model; + expected_read_only_model->UnPackTo(&expected_model); + + ExpectSameModels(model_, expected_model); } class QuantizeSVDFTest : public QuantizeModelTest { diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/lite/tools/optimize/test_util.cc index b22902a3e4b..5565fc4d657 100644 --- a/tensorflow/lite/tools/optimize/test_util.cc +++ b/tensorflow/lite/tools/optimize/test_util.cc @@ -57,6 +57,11 @@ const char* kModelPack = "pack.bin"; const char* kLstmCalibrated = "lstm_calibrated.bin"; const char* kLstmQuantized = "lstm_quantized.bin"; +const char* kUnidirectionalSequenceLstmCalibrated = + "unidirectional_sequence_lstm_calibrated.bin"; +const char* kUnidirectionalSequenceLstmQuantized = + "unidirectional_sequence_lstm_quantized.bin"; + const char* kModelWithMinimumOp = "minimum.bin"; const char* kModelWithMaximumOp = "maximum.bin"; const char* kLstmCalibrated2 = "lstm_calibrated2.bin"; diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/lite/tools/optimize/test_util.h index 99e8f0aedd3..4341a67d1ae 100644 --- a/tensorflow/lite/tools/optimize/test_util.h +++ b/tensorflow/lite/tools/optimize/test_util.h @@ -92,17 +92,20 @@ extern const char* kModelPack; extern const char* kLstmCalibrated; extern const char* kLstmQuantized; +// Test model with LSTM op that has peephole, without layer norm, without +// projection, without cifg. +extern const char* kLstmCalibrated2; +extern const char* kLstmQuantized2; + +extern const char* kUnidirectionalSequenceLstmCalibrated; +extern const char* kUnidirectionalSequenceLstmQuantized; + // Test model with a minimum op. extern const char* kModelWithMinimumOp; // Test model with a maximum op. extern const char* kModelWithMaximumOp; -// Test model with LSTM op that has peephole, without layer norm, without -// projection, without cifg. -extern const char* kLstmCalibrated2; -extern const char* kLstmQuantized2; - // Test model with a transpose op. extern const char* kModelWithTranspose; diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin b/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin new file mode 100644 index 00000000000..5712f85329d Binary files /dev/null and b/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin differ diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin b/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin new file mode 100644 index 00000000000..3b547b4f4ec Binary files /dev/null and b/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin differ diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc index a4cc4e0c9a0..cc81fe49cbc 100644 --- a/tensorflow/lite/tools/verifier.cc +++ b/tensorflow/lite/tools/verifier.cc @@ -126,6 +126,25 @@ bool VerifyStringTensorBuffer(const Tensor& tensor, const Buffer& buffer, return true; } +bool CheckArraySegments(const DimensionMetadata* dim_metadata) { + if (dim_metadata->array_segments() == nullptr) { + return false; + } + switch (dim_metadata->array_segments_type()) { + case SparseIndexVector_Int32Vector: + return (dim_metadata->array_segments_as_Int32Vector()->values() != + nullptr); + case SparseIndexVector_Uint16Vector: + return (dim_metadata->array_segments_as_Uint16Vector()->values() != + nullptr); + case SparseIndexVector_Uint8Vector: + return (dim_metadata->array_segments_as_Uint8Vector()->values() != + nullptr); + default: + return false; + } +} + int GetSizeOfSegments(const DimensionMetadata* dim_metadata) { switch (dim_metadata->array_segments_type()) { case SparseIndexVector_Int32Vector: @@ -155,6 +174,25 @@ int GetValueOfSegmentsAt(const DimensionMetadata* dim_metadata, const int i) { } } +bool CheckArrayIndices(const DimensionMetadata* dim_metadata) { + if (dim_metadata->array_indices() == nullptr) { + return false; + } + switch (dim_metadata->array_indices_type()) { + case SparseIndexVector_Int32Vector: + return (dim_metadata->array_indices_as_Int32Vector()->values() != + nullptr); + case SparseIndexVector_Uint16Vector: + return (dim_metadata->array_indices_as_Uint16Vector()->values() != + nullptr); + case SparseIndexVector_Uint8Vector: + return (dim_metadata->array_indices_as_Uint8Vector()->values() != + nullptr); + default: + return false; + } +} + int GetSizeOfIndices(const DimensionMetadata* dim_metadata) { switch (dim_metadata->array_indices_type()) { case SparseIndexVector_Int32Vector: @@ -205,9 +243,8 @@ absl::optional VerifyAndCountElements( // Each index in a dense dimension is stored implicitly. num_elements *= dim_metadata->dense_size(); } else { - const auto* array_segments = dim_metadata->array_segments(); - const auto* array_indices = dim_metadata->array_indices(); - if (array_segments == nullptr || array_indices == nullptr) { + if (!CheckArraySegments(dim_metadata) || + !CheckArrayIndices(dim_metadata)) { return absl::nullopt; } diff --git a/tensorflow/lite/tools/versioning/BUILD b/tensorflow/lite/tools/versioning/BUILD index 4f89a6531f8..06ac1968f52 100644 --- a/tensorflow/lite/tools/versioning/BUILD +++ b/tensorflow/lite/tools/versioning/BUILD @@ -26,6 +26,7 @@ cc_library( "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs_with_mutable", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@flatbuffers", diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 4d41b7d13e9..8627c492c70 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { namespace { @@ -621,9 +622,10 @@ TensorType GetTensorType(int32_t idx, const SubGraph* subgraph) { // options to decide op version. OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, const SubGraph* subgraph) { - OpSignature op_sig = {op_code->builtin_code()}; + auto builtin_code = GetBuiltinCode(op_code); + OpSignature op_sig = {builtin_code}; - switch (op_code->builtin_code()) { + switch (builtin_code) { case BuiltinOperator_DEPTHWISE_CONV_2D: { auto conv_option = op->builtin_options_as_DepthwiseConv2DOptions(); if (conv_option) { @@ -797,14 +799,15 @@ void UpdateOpVersion(uint8_t* model_buffer_pointer) { OperatorCode* op_code = model->mutable_operator_codes()->GetMutableObject(op->opcode_index()); - if (op_code->builtin_code() != BuiltinOperator_CUSTOM) { + auto builtin_code = GetBuiltinCode(op_code); + if (builtin_code != BuiltinOperator_CUSTOM) { OpSignature op_sig = GetOpSignature(op_code, op, subgraph); // Update builtin operator version. int32_t op_ver = GetBuiltinOperatorVersion(op_sig); if (!op_code->mutate_version(op_ver)) { LOG(ERROR) << "Can't set operator " - << EnumNameBuiltinOperator(op_code->builtin_code()) - << " to version " << op_ver; + << EnumNameBuiltinOperator(builtin_code) << " to version " + << op_ver; } } } diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 3c0a92f9df1..8dfd41f9b9d 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/schema/mutable/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { namespace { @@ -318,6 +319,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"}, {{BuiltinOperator_RANK, 1}, "1.14.0"}, {{BuiltinOperator_WHILE, 1}, "1.15.0"}, + {{BuiltinOperator_CUMSUM, 1}, kPendingReleaseVersion}, }); std::pair version_key = {op_code, op_version}; @@ -339,7 +341,7 @@ void UpdateMinimumRuntimeVersionForModel(uint8_t* model_buffer_pointer) { const OperatorCode* op_code = model->operator_codes()->Get(op->opcode_index()); std::string runtime_version = FindMinimumRuntimeVersionForOp( - op_code->builtin_code(), op_code->version()); + GetBuiltinCode(op_code), op_code->version()); if (runtime_version.empty() || runtime_version == kPendingReleaseVersion) { // In case we didn't find the current op in the map, or the operator diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index ff944035b22..08b43bb9d88 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -9,6 +9,7 @@ tensorflow/lite/micro/build_def.bzl tensorflow/python/autograph/core/config.py tensorflow/python/eager/benchmarks_test_base.py tensorflow/python/framework/tfrt_utils.py +tensorflow/python/keras/benchmarks/layer_benchmarks/run_xprof.py tensorflow/python/tpu/profiler/pip_package/BUILD tensorflow/python/tpu/profiler/pip_package/README tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh @@ -117,12 +118,13 @@ tensorflow/third_party/llvm/BUILD tensorflow/third_party/llvm/expand_cmake_vars.py tensorflow/third_party/llvm/llvm.autogenerated.BUILD tensorflow/third_party/llvm/llvm.bzl +tensorflow/third_party/llvm_openmp/BUILD +tensorflow/third_party/llvm_openmp/openmp.bzl tensorflow/third_party/lmdb.BUILD tensorflow/third_party/mkl/BUILD tensorflow/third_party/mkl/LICENSE tensorflow/third_party/mkl/MKL_LICENSE tensorflow/third_party/mkl/build_defs.bzl -tensorflow/third_party/mkl/mkl.BUILD tensorflow/third_party/mkl_dnn/LICENSE tensorflow/third_party/mkl_dnn/build_defs.bzl tensorflow/third_party/mkl_dnn/mkldnn.BUILD @@ -273,115 +275,16 @@ tensorflow/tools/build_info/BUILD tensorflow/tools/ci_build/horovod/gpu/nightly.sh tensorflow/tools/ci_build/release/common.sh tensorflow/tools/ci_build/release/common_win.bat -tensorflow/tools/ci_build/release/macos/cpu_libtensorflow/build.sh -tensorflow/tools/ci_build/release/macos/cpu_libtensorflow/release.sh -tensorflow/tools/ci_build/release/macos/cpu_py35_full/nightly_release.sh -tensorflow/tools/ci_build/release/macos/cpu_py35_full/nonpip.sh -tensorflow/tools/ci_build/release/macos/cpu_py35_full/nonpip_v1.sh -tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip.sh -tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip_v1.sh -tensorflow/tools/ci_build/release/macos/cpu_py35_full/release.sh -tensorflow/tools/ci_build/release/macos/cpu_py36_full/nightly_release.sh -tensorflow/tools/ci_build/release/macos/cpu_py36_full/nonpip.sh -tensorflow/tools/ci_build/release/macos/cpu_py36_full/nonpip_v1.sh -tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip.sh -tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip_v1.sh -tensorflow/tools/ci_build/release/macos/cpu_py37_full/nightly_release.sh -tensorflow/tools/ci_build/release/macos/cpu_py37_full/nonpip.sh -tensorflow/tools/ci_build/release/macos/cpu_py37_full/nonpip_v1.sh -tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip.sh -tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip_v1.sh -tensorflow/tools/ci_build/release/macos/cpu_py37_full/release.sh -tensorflow/tools/ci_build/release/macos/cpu_py38_full/nightly_release.sh -tensorflow/tools/ci_build/release/macos/cpu_py38_full/nonpip.sh -tensorflow/tools/ci_build/release/macos/cpu_py38_full/pip.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nightly_release.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nonpip.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nonpip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nightly_release.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nonpip.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nonpip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nightly_release.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nonpip.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nonpip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/nightly_release.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/nonpip.sh -tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/pip.sh tensorflow/tools/ci_build/release/ubuntu_16/custom_op/nightly.sh tensorflow/tools/ci_build/release/ubuntu_16/custom_op/release.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_pip_on_cpu/build.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nightly_release.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nonpip.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nonpip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nightly_release.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nonpip.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nonpip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nightly_release.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nonpip.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nonpip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip_v1.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/nightly_release.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/nonpip.sh -tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/pip.sh -tensorflow/tools/ci_build/release/ubuntu_16/libtensorflow/cpu/build.sh -tensorflow/tools/ci_build/release/ubuntu_16/libtensorflow/gpu/build.sh tensorflow/tools/ci_build/release/ubuntu_16/sanity/build.sh tensorflow/tools/ci_build/release/ubuntu_16/tpu_py37_full/nonpip.sh -tensorflow/tools/ci_build/release/windows/cpu_libtensorflow/nightly.bat -tensorflow/tools/ci_build/release/windows/cpu_libtensorflow/release.bat -tensorflow/tools/ci_build/release/windows/cpu_py35_full/nightly.bat -tensorflow/tools/ci_build/release/windows/cpu_py35_full/nightly_release.bat -tensorflow/tools/ci_build/release/windows/cpu_py35_full/release.bat -tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_pip_rename.sh -tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_v1.bat -tensorflow/tools/ci_build/release/windows/cpu_py36_full/nightly.bat -tensorflow/tools/ci_build/release/windows/cpu_py36_full/nightly_release.bat -tensorflow/tools/ci_build/release/windows/cpu_py36_full/release.bat tensorflow/tools/ci_build/release/windows/cpu_py36_full/release_pip_rename.sh -tensorflow/tools/ci_build/release/windows/cpu_py36_full/release_v1.bat -tensorflow/tools/ci_build/release/windows/cpu_py37_full/nightly.bat -tensorflow/tools/ci_build/release/windows/cpu_py37_full/nightly_release.bat -tensorflow/tools/ci_build/release/windows/cpu_py37_full/release.bat tensorflow/tools/ci_build/release/windows/cpu_py37_full/release_pip_rename.sh -tensorflow/tools/ci_build/release/windows/cpu_py37_full/release_v1.bat -tensorflow/tools/ci_build/release/windows/cpu_py38_full/nightly.bat -tensorflow/tools/ci_build/release/windows/cpu_py38_full/nightly_release.bat -tensorflow/tools/ci_build/release/windows/cpu_py38_full/release.bat tensorflow/tools/ci_build/release/windows/cpu_py38_full/release_pip_rename.sh -tensorflow/tools/ci_build/release/windows/gpu_libtensorflow/nightly.bat -tensorflow/tools/ci_build/release/windows/gpu_libtensorflow/release.bat -tensorflow/tools/ci_build/release/windows/gpu_pip_on_cpu/build.bat -tensorflow/tools/ci_build/release/windows/gpu_py35_full/nightly.bat -tensorflow/tools/ci_build/release/windows/gpu_py35_full/nightly_release.bat -tensorflow/tools/ci_build/release/windows/gpu_py35_full/release.bat -tensorflow/tools/ci_build/release/windows/gpu_py35_full/release_pip_rename.sh -tensorflow/tools/ci_build/release/windows/gpu_py35_full/release_v1.bat -tensorflow/tools/ci_build/release/windows/gpu_py36_full/nightly.bat -tensorflow/tools/ci_build/release/windows/gpu_py36_full/nightly_release.bat -tensorflow/tools/ci_build/release/windows/gpu_py36_full/release.bat tensorflow/tools/ci_build/release/windows/gpu_py36_full/release_pip_rename.sh -tensorflow/tools/ci_build/release/windows/gpu_py36_full/release_v1.bat -tensorflow/tools/ci_build/release/windows/gpu_py37_full/nightly.bat -tensorflow/tools/ci_build/release/windows/gpu_py37_full/nightly_release.bat -tensorflow/tools/ci_build/release/windows/gpu_py37_full/release.bat tensorflow/tools/ci_build/release/windows/gpu_py37_full/release_pip_rename.sh -tensorflow/tools/ci_build/release/windows/gpu_py37_full/release_v1.bat -tensorflow/tools/ci_build/release/windows/gpu_py38_full/nightly.bat -tensorflow/tools/ci_build/release/windows/gpu_py38_full/nightly_release.bat -tensorflow/tools/ci_build/release/windows/gpu_py38_full/release.bat tensorflow/tools/ci_build/release/windows/gpu_py38_full/release_pip_rename.sh -tensorflow/tools/ci_build/release/windows/upload_nightly_pip/upload.sh tensorflow/tools/ci_build/remote/BUILD tensorflow/tools/def_file_filter/BUILD tensorflow/tools/def_file_filter/BUILD.tpl diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 214f7e13d5c..6b5b392a681 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -358,7 +358,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":platform", ":platform_test", @@ -377,7 +376,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":platform", ":platform_test", @@ -389,7 +387,6 @@ tf_py_test( size = "small", srcs = ["platform/flags_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":platform", @@ -405,7 +402,6 @@ tf_py_test( "no_windows", "nomac", ], - tfrt_enabled = True, deps = [ ":client_testlib", ":platform", @@ -710,7 +706,6 @@ tf_python_pybind_extension( "@pybind11", "//third_party/python_runtime:headers", "//tensorflow/core:protos_all_cc", - "//tensorflow/c/experimental/saved_model/core:pywrap_required_hdrs", "//tensorflow/core:framework_headers_lib", "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//tensorflow/core:lib_headers_for_pybind", @@ -1067,7 +1062,7 @@ cc_library( "-parse_headers", ], visibility = tf_external_workspace_visible(visibility + [ - "//learning/deepmind/courier:__subpackages__", + "//tensorflow:ndarray_tensor_allow_list", ]), deps = [ ":numpy_lib", @@ -1192,7 +1187,6 @@ tf_py_test( name = "decorator_utils_test", srcs = ["util/decorator_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":platform", @@ -1204,7 +1198,6 @@ tf_py_test( name = "deprecation_test", srcs = ["util/deprecation_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":platform", @@ -1216,7 +1209,6 @@ tf_py_test( name = "dispatch_test", srcs = ["util/dispatch_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":platform", @@ -1228,7 +1220,6 @@ tf_py_test( name = "keyword_args_test", srcs = ["util/keyword_args_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":util", @@ -1383,6 +1374,7 @@ py_library( ":_pywrap_kernel_registry", ":_pywrap_py_exception_registry", ":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed. + ":_pywrap_python_api_dispatcher", ":_pywrap_python_op_gen", ":_pywrap_quantize_training", ":_pywrap_stacktrace_handler", @@ -1566,7 +1558,6 @@ tf_py_test( srcs = ["framework/function_def_to_graph_test.py"], python_version = "PY3", tags = ["no_pip"], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -1695,7 +1686,6 @@ tf_py_test( srcs = ["framework/py_context_manager_test.py"], python_version = "PY3", tags = ["no_pip"], - tfrt_enabled = True, deps = [ ":_py_context_manager", ], @@ -1750,7 +1740,45 @@ tf_py_test( srcs = ["framework/op_def_util_test.py"], python_version = "PY3", tags = ["no_pip"], - tfrt_enabled = True, +) + +cc_library( + name = "python_api_dispatcher", + srcs = ["framework/python_api_dispatcher.cc"], + hdrs = ["framework/python_api_dispatcher.h"], + deps = [ + ":cpp_python_util", + ":safe_pyobject_ptr", + "//tensorflow/core/platform:logging", + "//third_party/python_runtime:headers", # buildcleaner: keep + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + ], +) + +# Note: this target is only used by python_api_dispatcher_test. +tf_python_pybind_extension( + name = "_pywrap_python_api_dispatcher", + # testonly = True, + srcs = ["framework/python_api_dispatcher_wrapper.cc"], + hdrs = ["framework/python_api_dispatcher.h"], + module_name = "_pywrap_python_api_dispatcher", + deps = [ + ":safe_pyobject_ptr_required_hdrs", + "//third_party/python_runtime:headers", # buildcleaner: keep + "@pybind11", + ], +) + +tf_py_test( + name = "python_api_dispatcher_test", + srcs = ["framework/python_api_dispatcher_test.py"], + python_version = "PY3", + tags = ["no_pip"], + deps = [ + ":_pywrap_python_api_dispatcher", + ":client_testlib", + ], ) py_library( @@ -1964,7 +1992,6 @@ tf_py_test( size = "small", srcs = ["framework/smart_cond_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":constant_op", @@ -2037,7 +2064,6 @@ tf_py_test( srcs = ["framework/composite_tensor_utils_test.py"], main = "framework/composite_tensor_utils_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":composite_tensor", @@ -2300,7 +2326,6 @@ tf_py_test( srcs = ["framework/constant_op_test.py"], main = "framework/constant_op_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":constant_op", ], @@ -2312,7 +2337,6 @@ tf_py_test( srcs = ["framework/registry_test.py"], main = "framework/registry_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":framework_for_generated_wrappers", @@ -2326,7 +2350,6 @@ tf_py_test( srcs = ["framework/errors_test.py"], main = "framework/errors_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":errors", @@ -2340,7 +2363,6 @@ tf_py_test( srcs = ["framework/error_interpolation_test.py"], main = "framework/error_interpolation_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":constant_op", @@ -2355,7 +2377,6 @@ tf_py_test( srcs = ["framework/subscribe_test.py"], main = "framework/subscribe_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework", ":framework_for_generated_wrappers", @@ -2398,7 +2419,6 @@ tf_py_test( tags = [ "no_pip", ], - tfrt_enabled = True, deps = [ ":client_testlib", ":platform", @@ -2411,7 +2431,6 @@ tf_py_test( srcs = ["framework/proto_test.py"], main = "framework/proto_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":framework_for_generated_wrappers", @@ -2500,7 +2519,6 @@ tf_py_test( srcs = ["framework/versions_test.py"], main = "framework/versions_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":framework_for_generated_wrappers", @@ -2549,7 +2567,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -2573,7 +2590,6 @@ tf_py_test( srcs = ["framework/traceable_stack_test.py"], main = "framework/traceable_stack_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_test_lib", ":platform_test", @@ -2628,7 +2644,6 @@ tf_py_test( srcs = ["framework/common_shapes_test.py"], main = "framework/common_shapes_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework", ":framework_for_generated_wrappers", @@ -2645,7 +2660,6 @@ tf_py_test( main = "framework/ops_test.py", python_version = "PY3", tags = ["no_pip"], # test_ops_2 is not available in pip. - tfrt_enabled = True, deps = [ ":cond_v2", ":control_flow_ops", @@ -2675,7 +2689,6 @@ tf_py_test( srcs = ["framework/ops_enable_eager_test.py"], main = "framework/ops_enable_eager_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework", ":platform_test", @@ -2689,7 +2702,6 @@ tf_py_test( srcs = ["framework/tensor_shape_test.py"], main = "framework/tensor_shape_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -2705,7 +2717,6 @@ tf_py_test( srcs = ["framework/type_spec_test.py"], main = "framework/type_spec_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -2721,7 +2732,6 @@ tf_py_test( srcs = ["framework/tensor_spec_test.py"], main = "framework/tensor_spec_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -2754,7 +2764,6 @@ tf_py_test( srcs = ["framework/device_spec_test.py"], main = "framework/device_spec_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -2769,7 +2778,6 @@ tf_py_test( srcs = ["framework/device_test.py"], main = "framework/device_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -2784,7 +2792,6 @@ tf_py_test( srcs = ["framework/random_seed_test.py"], main = "framework/random_seed_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":framework", @@ -2797,7 +2804,6 @@ tf_py_test( srcs = ["framework/tensor_shape_div_test.py"], main = "framework/tensor_shape_div_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -2814,7 +2820,6 @@ tf_py_test( main = "framework/tensor_util_test.py", python_version = "PY3", tags = ["no_windows"], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -2869,7 +2874,6 @@ tf_py_test( "nomsan", # TODO(b/149948895): Re-enable. "notsan", # TODO(b/149948895): Re-enable. ], - tfrt_enabled = True, deps = [ ":framework_test_lib", # TODO(kkb): Find more appropriate place to add `memory_checker` as deps @@ -2895,7 +2899,6 @@ tf_py_test( srcs = ["framework/dtypes_test.py"], main = "framework/dtypes_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -2911,7 +2914,6 @@ tf_py_test( size = "small", srcs = ["framework/op_def_library_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -2925,7 +2927,6 @@ tf_py_test( srcs = ["framework/kernels_test.py"], main = "framework/kernels_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_test_lib", ":kernels", @@ -3467,7 +3468,6 @@ tf_py_test( size = "small", srcs = ["ops/clip_ops_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":clip_ops", @@ -3493,7 +3493,6 @@ tf_py_test( size = "medium", srcs = ["ops/clustering_ops_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":clustering_ops", @@ -3538,7 +3537,6 @@ tf_py_test( "no_windows", "nomac", ], - tfrt_enabled = True, xla_enable_strict_auto_jit = True, deps = [ ":client_testlib", @@ -3662,7 +3660,6 @@ tf_py_test( size = "small", srcs = ["ops/control_flow_v2_toggles_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":control_flow_util_v2", @@ -3676,7 +3673,6 @@ tf_py_test( size = "small", srcs = ["ops/control_flow_v2_enable_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":control_flow_util", @@ -3698,7 +3694,6 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":client_testlib", ":control_flow_util", @@ -3781,7 +3776,6 @@ tf_py_test( size = "small", srcs = ["ops/bincount_ops_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":bincount_ops", ":platform_test", @@ -4316,6 +4310,7 @@ py_library( ":array_ops", ":handle_data_util", ":list_ops_gen", + "//third_party/py/numpy", ], ) @@ -4659,7 +4654,6 @@ tf_py_test( name = "sort_ops_test", srcs = ["ops/sort_ops_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -4759,7 +4753,6 @@ cuda_py_test( name = "rnn_grad_test", srcs = ["ops/rnn_grad_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -5042,7 +5035,6 @@ cuda_py_test( srcs = ["ops/bitwise_ops_test.py"], python_version = "PY3", tags = ["no_windows"], - tfrt_enabled = True, deps = [ ":bitwise_ops", ":constant_op", @@ -5100,7 +5092,6 @@ cuda_py_test( size = "medium", srcs = ["ops/gradient_checker_v2_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -5155,7 +5146,6 @@ cuda_py_test( size = "small", srcs = ["ops/histogram_ops_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -5173,7 +5163,6 @@ cuda_py_test( srcs = ["ops/image_grad_deterministic_test.py"], python_version = "PY3", shard_count = 5, - tfrt_enabled = True, deps = [ ":image_grad_test_base", ], @@ -5185,7 +5174,6 @@ cuda_py_test( srcs = ["ops/image_grad_test.py"], python_version = "PY3", shard_count = 5, - tfrt_enabled = True, deps = [ ":image_grad_test_base", ], @@ -5210,6 +5198,9 @@ cuda_py_test( data = ["//tensorflow/core:image_testdata"], python_version = "PY3", shard_count = 16, + tags = [ + "no_cuda_asan", # TODO(b/171511582): re-enable. + ], deps = [ ":array_ops", ":client", @@ -5235,7 +5226,6 @@ cuda_py_test( size = "small", srcs = ["ops/init_ops_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":framework_ops", @@ -5251,7 +5241,6 @@ cuda_py_test( size = "medium", srcs = ["ops/init_ops_v2_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -5269,7 +5258,6 @@ cuda_py_test( srcs = ["ops/math_grad_test.py"], python_version = "PY3", tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -5309,7 +5297,6 @@ cuda_py_test( tags = [ "no_windows_gpu", ], - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -5326,7 +5313,6 @@ cuda_py_test( python_version = "PY3", shard_count = 4, tags = ["no_windows"], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -5416,6 +5402,7 @@ py_test( ":array_ops", ":client", ":client_testlib", + "//third_party/py/numpy", ], ) @@ -5430,7 +5417,6 @@ cuda_py_test( "no_oss", # TODO(b/149565560) "no_windows_gpu", ], - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -5468,7 +5454,6 @@ tf_py_test( size = "small", srcs = ["ops/variable_spec_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -5515,7 +5500,6 @@ tf_py_test( name = "tf_export_test", srcs = ["util/tf_export_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":platform", @@ -5576,7 +5560,6 @@ tf_py_test( name = "tf_stack_test", srcs = ["util/tf_stack_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":tf_export", @@ -5639,7 +5622,6 @@ tf_py_test( "no_pip", # b/168621686 "no_windows", # b/169275019 ], - tfrt_enabled = True, deps = [ ":_function_parameter_canonicalizer_binding_for_test", ":client_testlib", @@ -5673,6 +5655,7 @@ py_library( "//tensorflow:__pkg__", "//third_party/py/tensorflow_core:__subpackages__", "//third_party/py/tf_agents:__subpackages__", + "//third_party/py/tfx:__subpackages__", ], deps = [ ":_pywrap_tensor_float_32_execution", @@ -5688,6 +5671,7 @@ py_library( "//third_party/py/numpy", "@six_archive//:six", "@wrapt", + "//tensorflow/tools/docs:doc_controls", "//tensorflow/tools/compatibility:all_renames_v2", ], ) @@ -5697,7 +5681,6 @@ tf_py_test( size = "small", srcs = ["util/object_identity_test.py"], python_version = "PY3", - tfrt_enabled = True, ) # Placeholder for intenal nest_test comments. @@ -5707,7 +5690,6 @@ tf_py_test( srcs = ["util/nest_test.py"], main = "util/nest_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [":util_nest_test_main_lib"], ) @@ -5733,7 +5715,6 @@ tf_py_test( srcs = ["util/serialization_test.py"], main = "util/serialization_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":util", @@ -5744,7 +5725,6 @@ tf_py_test( name = "function_utils_test", srcs = ["util/function_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":util", @@ -5756,7 +5736,6 @@ tf_py_test( size = "small", srcs = ["util/tf_contextlib_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":util", @@ -5768,7 +5747,6 @@ tf_py_test( size = "small", srcs = ["util/tf_decorator_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":util", @@ -5792,7 +5770,6 @@ tf_py_test( size = "small", srcs = ["util/tf_should_use_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":tf_should_use", @@ -5804,7 +5781,6 @@ tf_py_test( size = "small", srcs = ["util/tf_inspect_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":util", @@ -5829,7 +5805,6 @@ tf_py_test( srcs = ["util/lock_util_test.py"], main = "util/lock_util_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":util", @@ -5842,7 +5817,6 @@ tf_py_test( size = "small", srcs = ["util/module_wrapper_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":util", @@ -5886,7 +5860,6 @@ tf_py_test( main = "util/protobuf/compare_test.py", python_version = "PY3", tags = ["no_pip"], # compare_test_pb2 proto is not available in pip. - tfrt_enabled = True, deps = [ ":compare_test_proto_py", ":platform_test", @@ -5901,7 +5874,6 @@ tf_py_test( srcs = ["util/example_parser_configuration_test.py"], main = "util/example_parser_configuration_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client", @@ -5917,7 +5889,6 @@ tf_py_test( size = "small", srcs = ["client/events_writer_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":errors", ":framework_test_lib", @@ -6047,12 +6018,16 @@ pywrap_tensorflow_macro( ":pybind11_lib", ":pybind11_status", ":pybind11_proto", + ":python_api_dispatcher", ":python_op_gen", ":safe_pyobject_ptr", ":tf_session_helper", "//third_party/python_runtime:headers", "//tensorflow/c:c_api", + "//tensorflow/c:kernels", + "//tensorflow/c:ops", "//tensorflow/c:c_api_experimental", + "//tensorflow/c/experimental/stream_executor:stream_executor", "//tensorflow/c:checkpoint_reader", "//tensorflow/c:python_api", "//tensorflow/c:tf_status_helper", @@ -6116,6 +6091,7 @@ filegroup( ":numpy_lib", # checkpoint_reader ":py_exception_registry", # py_exception_registry ":py_func_lib", # py_func + ":python_api_dispatcher", # python_api_dispatcher ":python_op_gen", # python_op_gen ":safe_ptr", # checkpoint_reader "//tensorflow/c:checkpoint_reader", # checkpoint_reader @@ -6432,7 +6408,6 @@ tf_py_test( "no_pip_gpu", "notsan", # data race due to b/62910646 ], - tfrt_enabled = True, deps = [ ":client", ":framework", @@ -6452,7 +6427,6 @@ tf_py_test( "no_gpu", "no_windows", ], - tfrt_enabled = True, deps = [ ":array_ops", ":client", @@ -6477,7 +6451,6 @@ cuda_py_test( "gpu_cupti", "no_gpu", # b/154742661 ], - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Graph structure is different with autojit deps = [ ":client", @@ -6497,7 +6470,6 @@ cuda_py_test( "no_gpu", # b/127386241 "no_windows_gpu", ], - tfrt_enabled = True, deps = [ ":client", ":client_testlib", @@ -6512,7 +6484,6 @@ tf_py_test( size = "small", srcs = ["framework/c_api_util_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":c_api_util", ":framework_test_lib", @@ -6525,7 +6496,6 @@ tf_py_test( size = "small", srcs = ["framework/graph_util_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client", ":client_testlib", @@ -6560,7 +6530,6 @@ tf_py_test( size = "small", srcs = ["lib/core/bfloat16_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":lib", @@ -6577,7 +6546,6 @@ tf_py_test( "no_rocm", "no_windows", ], - tfrt_enabled = True, deps = [ ":client_testlib", ":errors", @@ -6590,7 +6558,6 @@ tf_py_test( size = "small", srcs = ["lib/io/tf_record_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":errors", @@ -6795,7 +6762,6 @@ cuda_py_test( main = "ops/accumulate_n_benchmark.py", python_version = "PY3", shard_count = 6, - tfrt_enabled = True, deps = [ ":array_ops", ":client", @@ -6815,7 +6781,6 @@ cuda_py_test( srcs = ["ops/batch_norm_benchmark.py"], main = "ops/batch_norm_benchmark.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client", @@ -6837,7 +6802,6 @@ cuda_py_test( srcs = ["ops/collective_ops_benchmark.py"], main = "ops/collective_ops_benchmark.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client", @@ -6855,7 +6819,6 @@ cuda_py_test( srcs = ["ops/concat_benchmark.py"], main = "ops/concat_benchmark.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client", @@ -6874,7 +6837,6 @@ cuda_py_test( srcs = ["ops/control_flow_ops_benchmark.py"], main = "ops/control_flow_ops_benchmark.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":constant_op", @@ -6890,7 +6852,6 @@ cuda_py_test( srcs = ["ops/conv2d_benchmark.py"], main = "ops/conv2d_benchmark.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":client", ":client_testlib", @@ -6911,7 +6872,6 @@ cuda_py_test( srcs = ["ops/split_benchmark.py"], main = "ops/split_benchmark.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client", @@ -6932,7 +6892,6 @@ cuda_py_test( srcs = ["ops/transpose_benchmark.py"], main = "ops/transpose_benchmark.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client", @@ -6953,7 +6912,6 @@ cuda_py_test( srcs = ["ops/matmul_benchmark.py"], main = "ops/matmul_benchmark.py", python_version = "PY3", - tfrt_enabled = True, deps = [":matmul_benchmark_main_lib"], ) @@ -6983,7 +6941,6 @@ cuda_py_test( grpc_enabled = True, main = "client/session_benchmark.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client", @@ -7002,7 +6959,6 @@ cuda_py_test( srcs = ["framework/graph_building_benchmark.py"], main = "framework/graph_building_benchmark.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -7018,7 +6974,6 @@ cuda_py_test( size = "medium", srcs = ["ops/nn_grad_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":framework_for_generated_wrappers", @@ -7073,7 +7028,6 @@ tf_py_test( "grappler", "no_pip", # tf_optimizer is not available in pip. ], - tfrt_enabled = True, deps = [ ":client_testlib", ":framework_for_generated_wrappers", @@ -7094,7 +7048,6 @@ tf_py_test( "grappler", "no_pip", # tf_optimizer is not available in pip. ], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -7213,7 +7166,6 @@ tf_py_test( "grappler", "no_pip", # tf_optimizer is not available in pip. ], - tfrt_enabled = True, deps = [ ":client_testlib", ":framework_for_generated_wrappers", @@ -7235,7 +7187,6 @@ tf_py_test( tags = [ "grappler", ], - tfrt_enabled = True, deps = [ ":client_testlib", ":framework_for_generated_wrappers", @@ -7371,7 +7322,6 @@ tf_py_test( "no_pip", "no_windows", # TODO(b/151942037) ], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -7406,7 +7356,6 @@ tf_py_test( "grappler", "no_pip", ], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -7427,7 +7376,6 @@ cuda_py_test( ], python_version = "PY3", tags = ["grappler"], - tfrt_enabled = True, # This test analyzes the graph, but XLA changes the names of nodes. xla_enable_strict_auto_jit = False, deps = [ @@ -7610,6 +7558,9 @@ tf_python_pybind_extension( "//tensorflow/python/eager:pywrap_required_hdrs", ], module_name = "_pywrap_tfe", + # Only include TensorFlow header-only targets here. + # If a cc_library needs to depend on TensorFlow .cc files through srcs or + # deps, then you can use cc_header_only_library to keep only headers. deps = [ ":safe_pyobject_ptr", ":pybind11_lib", @@ -7620,14 +7571,14 @@ tf_python_pybind_extension( "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "//tensorflow/c:pywrap_required_hdrs", "@pybind11", "//third_party/python_runtime:headers", - "//tensorflow/c/experimental/saved_model/core:pywrap_required_hdrs", "//tensorflow/compiler/jit:flags_headers_only", "//tensorflow/compiler/jit:get_compiler_ir_hdrs_only", - "//tensorflow/c/eager:tfe_tensorhandle_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal_hdrs_only", "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_headers_for_pybind", @@ -7724,7 +7675,6 @@ cuda_py_test( name = "raw_ops_test", srcs = ["ops/raw_ops_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ], diff --git a/tensorflow/python/autograph/BUILD b/tensorflow/python/autograph/BUILD index 874b99464d8..f3b4dfcb558 100644 --- a/tensorflow/python/autograph/BUILD +++ b/tensorflow/python/autograph/BUILD @@ -21,7 +21,7 @@ py_strict_library( srcs = [ "__init__.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ "//tensorflow/python:util", diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index fd8ec1dbaa3..e3153d3f93c 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -35,7 +35,7 @@ py_library( "slices.py", "variables.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/python:util", @@ -52,7 +52,7 @@ py_test( name = "asserts_test", srcs = ["asserts_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":converters", "//tensorflow/python:client_testlib", @@ -64,7 +64,7 @@ py_test( name = "break_statements_test", srcs = ["break_statements_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":converters", "//tensorflow/python:client_testlib", @@ -88,7 +88,7 @@ py_test( name = "conditional_expressions_test", srcs = ["conditional_expressions_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":converters", "//tensorflow/python:client_testlib", @@ -100,7 +100,7 @@ py_test( name = "continue_statements_test", srcs = ["continue_statements_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":converters", "//tensorflow/python:client_testlib", @@ -126,7 +126,7 @@ py_test( name = "directives_test", srcs = ["directives_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":converters", "//tensorflow/python:client_testlib", @@ -151,7 +151,7 @@ py_test( name = "list_comprehensions_test", srcs = ["list_comprehensions_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":converters", "//tensorflow/python:client_testlib", @@ -163,7 +163,7 @@ py_test( name = "lists_test", srcs = ["lists_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":converters", "//tensorflow/python:client_testlib", @@ -175,7 +175,7 @@ py_test( name = "logical_expressions_test", srcs = ["logical_expressions_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", tags = ["notsan"], # b/163218460 deps = [ ":converters", @@ -188,7 +188,7 @@ py_test( name = "return_statements_test", srcs = ["return_statements_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":converters", "//tensorflow/python:client_testlib", @@ -201,7 +201,7 @@ py_test( name = "slices_test", srcs = ["slices_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":converters", "//tensorflow/python:client_testlib", @@ -214,7 +214,7 @@ py_test( name = "variables_test", srcs = ["variables_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":converters", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index 4eace00fcaf..7b30b5723be 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -33,10 +33,6 @@ from tensorflow.python.autograph.pyct.static_analysis import annos from tensorflow.python.autograph.pyct.static_analysis import liveness from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs -from tensorflow.python.autograph.utils import compat_util - - -# TODO(mdan): Refactor functions to make them smaller. class _Function(object): @@ -419,6 +415,3 @@ def transform(node, ctx): node = ControlFlowTransformer(ctx).visit(node) return node - - -compat_util.deprecated_py2_support(__name__) diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD index 4a5c50dac55..77d34240d24 100644 --- a/tensorflow/python/autograph/core/BUILD +++ b/tensorflow/python/autograph/core/BUILD @@ -26,7 +26,7 @@ py_library( "function_wrappers.py", "unsupported_features_checker.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/python:framework_ops", @@ -43,7 +43,7 @@ py_library( srcs = [ "converter_testing.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//tensorflow:__subpackages__"], deps = [ ":core", @@ -61,7 +61,7 @@ py_test( name = "converter_test", srcs = ["converter_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":core", ":test_lib", @@ -73,7 +73,7 @@ py_test( name = "function_wrappers_test", srcs = ["function_wrappers_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":core", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/impl/BUILD b/tensorflow/python/autograph/impl/BUILD index 4c5475bbb74..f7afd32f293 100644 --- a/tensorflow/python/autograph/impl/BUILD +++ b/tensorflow/python/autograph/impl/BUILD @@ -22,7 +22,7 @@ py_library( "api.py", "conversion.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/python:platform", @@ -54,7 +54,6 @@ tf_py_test( tf_py_test( name = "conversion_test", srcs = ["conversion_test.py"], - tfrt_enabled = True, deps = [ ":impl", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/lang/BUILD b/tensorflow/python/autograph/lang/BUILD index ceccc6f0c93..e0db4f0d5e2 100644 --- a/tensorflow/python/autograph/lang/BUILD +++ b/tensorflow/python/autograph/lang/BUILD @@ -22,7 +22,7 @@ py_library( "directives.py", "special_functions.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/python/autograph/operators", @@ -33,7 +33,7 @@ py_test( name = "special_functions_test", srcs = ["special_functions_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":lang", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD index 13b3b7a1764..ab9babf9149 100644 --- a/tensorflow/python/autograph/operators/BUILD +++ b/tensorflow/python/autograph/operators/BUILD @@ -22,7 +22,6 @@ py_library( "__init__.py", "conditional_expressions.py", "control_flow.py", - "control_flow_deprecated_py2.py", "data_structures.py", "exceptions.py", "logical.py", @@ -30,7 +29,7 @@ py_library( "slices.py", "variables.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/python:array_ops", @@ -54,7 +53,7 @@ py_test( name = "data_structures_test", srcs = ["data_structures_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":operators", "//tensorflow/python:client_testlib", @@ -66,9 +65,6 @@ py_test( srcs = ["conditional_expressions_test.py"], python_version = "PY3", srcs_version = "PY3", - tags = [ - "no_oss_py2", - ], deps = [ ":operators", "//tensorflow/python:client_testlib", @@ -82,7 +78,6 @@ py_test( srcs_version = "PY3", tags = [ "no_gpu", # b/127001953 - "no_oss_py2", ], deps = [ ":operators", @@ -96,7 +91,7 @@ py_test( name = "exceptions_test", srcs = ["exceptions_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":operators", "//tensorflow/python:client_testlib", @@ -107,7 +102,7 @@ py_test( name = "logical_test", srcs = ["logical_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":operators", "//tensorflow/python:client_testlib", @@ -134,7 +129,7 @@ py_test( name = "slices_test", srcs = ["slices_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":operators", "//tensorflow/python:client_testlib", @@ -145,7 +140,7 @@ py_test( name = "variables_test", srcs = ["variables_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":operators", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index aaa4808cb0a..cb9c67e3000 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -67,7 +67,6 @@ import numpy as np from tensorflow.python.autograph.operators import py_builtins from tensorflow.python.autograph.operators import variables from tensorflow.python.autograph.utils import ag_logging -from tensorflow.python.autograph.utils import compat_util from tensorflow.python.autograph.utils import misc from tensorflow.python.autograph.utils import tensors from tensorflow.python.data.experimental.ops import scan_ops @@ -490,24 +489,26 @@ def _known_len_tf_for_stmt( ta = tensor_array_ops.TensorArray(iter_.dtype, size=n) iter_ = ta.unstack(iter_) - iterate_index = compat_util.BasicRef(0) + iterate_index = 0 def aug_get_state(): - return (iterate_index.value,) + get_state() + return (iterate_index,) + get_state() def aug_set_state(aug_loop_vars): - # TOOD(mdan): Use starred assignment once we can switch to Py3-only syntax. - iterate_index.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:] + nonlocal iterate_index + # TODO(b/171479293): Drop the lint override. + iterate_index, *loop_vars = aug_loop_vars # pylint:disable=unused-variable # The iteration index is not "output" by the for loop. If the iterate # is used outside the loop, it will appear in the loop vars separately. set_state(loop_vars) def aug_body(): - body(iter_.read(iterate_index.value)) - iterate_index.value += 1 + nonlocal iterate_index + body(iter_.read(iterate_index)) + iterate_index += 1 def aug_test(): - main_test = iterate_index.value < n + main_test = iterate_index < n if extra_test is not None: return control_flow_ops.cond(main_test, extra_test, lambda: False) return main_test @@ -536,24 +537,26 @@ def _tf_ragged_for_stmt( else: n = iter_.row_lengths()[0] - iterate_index = compat_util.BasicRef(0) + iterate_index = 0 def aug_get_state(): - return (iterate_index.value,) + get_state() + return (iterate_index,) + get_state() def aug_set_state(aug_loop_vars): - # TOOD(mdan): Use starred assignment once we can switch to Py3-only syntax. - iterate_index.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:] + nonlocal iterate_index + # TODO(b/171479293): Drop the lint override. + iterate_index, *loop_vars = aug_loop_vars # pylint:disable=unused-variable # The iteration index is not "output" by the for loop. If the iterate # is used outside the loop, it will appear in the loop vars separately. set_state(loop_vars) def aug_body(): - body(iter_[iterate_index.value]) - iterate_index.value += 1 + nonlocal iterate_index + body(iter_[iterate_index]) + iterate_index += 1 def aug_test(): - main_test = iterate_index.value < n + main_test = iterate_index < n if extra_test is not None: return control_flow_ops.cond(main_test, extra_test, lambda: False) return main_test @@ -574,7 +577,7 @@ def _tf_range_for_stmt( """Overload of for_stmt that iterates over a TF range (and elides it).""" start, limit, delta = iter_.op.inputs - iterate = compat_util.BasicRef(start) + iterate = start def _value_or(name, var, default): if (name == opts['iterate_names'] and isinstance(var, variables.Undefined)): @@ -584,33 +587,35 @@ def _tf_range_for_stmt( def aug_get_state(): state_vars = get_state() state_vars = tuple( - _value_or(name, var, iterate.value) + _value_or(name, var, iterate) for name, var in zip(symbol_names, state_vars)) - return (iterate.value,) + state_vars + return (iterate,) + state_vars def aug_set_state(aug_loop_vars): - # TOOD(mdan): Use starred assignment once we can switch to Py3-only syntax. - iterate.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:] + nonlocal iterate + # TODO(b/171479293): Drop the lint override. + iterate, *loop_vars = aug_loop_vars # pylint:disable=unused-variable # The iteration index is not "output" by the for loop. If the iterate # is used outside the loop, it will appear in the loop vars separately. set_state(loop_vars) def aug_body(): - body(iterate.value) - iterate.value += delta + nonlocal iterate + body(iterate) + iterate += delta def aug_test(): # TODO(b/159713842): Remove once constant folding works. const_delta = tensor_util.constant_value(delta) if const_delta is not None: if const_delta >= 0: - main_test = iterate.value < limit + main_test = iterate < limit else: - main_test = iterate.value > limit + main_test = iterate > limit else: main_test = math_ops.logical_or( - math_ops.logical_and(delta >= 0, iterate.value < limit), - math_ops.logical_and(delta < 0, iterate.value > limit)) + math_ops.logical_and(delta >= 0, iterate < limit), + math_ops.logical_and(delta < 0, iterate > limit)) if extra_test is not None: main_test = control_flow_ops.cond(main_test, extra_test, lambda: False) @@ -633,14 +638,15 @@ def _tf_iterator_for_stmt( iter_, extra_test, body, get_state, set_state, symbol_names, opts): """Overload of for_stmt that iterates over TF Iterators. See for_loop.""" symbol_names = ('',) + symbol_names - has_next = compat_util.BasicRef(True) + has_next = True def aug_get_state(): - return (has_next.value,) + get_state() + return (has_next,) + get_state() def aug_set_state(aug_loop_vars): - # TOOD(mdan): Use starred assignment once we can switch to Py3-only syntax. - has_next.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:] + nonlocal has_next + # TODO(b/171479293): Drop the lint override. + has_next, *loop_vars = aug_loop_vars # pylint:disable=unused-variable set_state(loop_vars) init_vars = aug_get_state() @@ -648,8 +654,9 @@ def _tf_iterator_for_stmt( def aug_body(): """Main body passed to _tf_while_stmt.""" + nonlocal has_next opt_iterate = iter_.get_next_as_optional() - has_next.value = opt_iterate.has_value() + has_next = opt_iterate.has_value() loop_vars = aug_get_state() # updated by set_state() in _tf_while_loop. def main_path(): @@ -669,13 +676,13 @@ def _tf_iterator_for_stmt( # Calling set_state so that get_state() _tf_while_loop sees the conditional # tensors. aug_set_state( - control_flow_ops.cond(has_next.value, main_path, noop_path)) + control_flow_ops.cond(has_next, main_path, noop_path)) def aug_test(): # This value takes a complicated path to get here: # prev_iteration_body -> get_state -> tf.while_loop (as loop var) - # -> current_iteration_body -> set_state -> has_next.value - main_test = has_next.value + # -> current_iteration_body -> set_state -> has_next + main_test = has_next if extra_test is not None: return control_flow_ops.cond(main_test, extra_test, lambda: False) return main_test @@ -1216,6 +1223,3 @@ def _tf_if_stmt( def _py_if_stmt(cond, body, orelse): """Overload of if_stmt that executes a Python if statement.""" return body() if cond else orelse() - - -compat_util.deprecated_py2_support(__name__) diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py index b35e51f3e7d..69ae22ad42f 100644 --- a/tensorflow/python/autograph/operators/py_builtins.py +++ b/tensorflow/python/autograph/operators/py_builtins.py @@ -429,8 +429,7 @@ def map_(fn, *iterables): def _tf_dataset_map(fn, *iterables): - zipped_dataset = dataset_ops.DatasetV2.zip(iterables) - return zipped_dataset.map(fn, num_parallel_calls=dataset_ops.AUTOTUNE) + return dataset_ops.DatasetV2.zip(iterables).map(fn) def _py_map(fn, *iterables): diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index 8ff18fcf2f4..09f98682efa 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -29,7 +29,6 @@ py_library( "gast_util.py", "inspect_utils.py", "loader.py", - "loader_deprecated_py2.py", "naming.py", "origin_info.py", "parser.py", @@ -39,7 +38,7 @@ py_library( "transformer.py", "transpiler.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ "//tensorflow/python/autograph/pyct/common_transformers", @@ -54,7 +53,7 @@ py_test( name = "anno_test", srcs = ["anno_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -65,7 +64,7 @@ py_test( name = "ast_util_test", srcs = ["ast_util_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", tags = [ "no_oss_py2", ], @@ -80,7 +79,7 @@ py_test( name = "cache_test", srcs = ["cache_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", tags = [ "no_oss_py2", ], @@ -95,7 +94,7 @@ py_test( name = "cfg_test", srcs = ["cfg_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", tags = [ "no_oss_py2", ], @@ -110,7 +109,7 @@ py_test( name = "loader_test", srcs = ["loader_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -122,7 +121,7 @@ py_test( name = "error_utils_test", srcs = ["error_utils_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -134,7 +133,7 @@ py_test( name = "inspect_utils_test", srcs = ["inspect_utils_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -153,7 +152,7 @@ py_test( name = "naming_test", srcs = ["naming_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -164,7 +163,7 @@ py_test( name = "origin_info_test", srcs = ["origin_info_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -176,7 +175,7 @@ py_test( name = "parser_test", srcs = ["parser_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -187,7 +186,7 @@ py_test( name = "pretty_printer_test", srcs = ["pretty_printer_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -198,7 +197,7 @@ py_test( name = "qual_names_test", srcs = ["qual_names_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -209,7 +208,7 @@ py_test( name = "templates_test", srcs = ["templates_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -222,7 +221,7 @@ py_test( name = "transformer_test", srcs = ["transformer_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -234,7 +233,7 @@ py_test( name = "transpiler_test", srcs = ["transpiler_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":pyct", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/pyct/common_transformers/BUILD b/tensorflow/python/autograph/pyct/common_transformers/BUILD index 61856a590ae..ecb50b2cece 100644 --- a/tensorflow/python/autograph/pyct/common_transformers/BUILD +++ b/tensorflow/python/autograph/pyct/common_transformers/BUILD @@ -21,7 +21,7 @@ py_library( srcs = [ "anf.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ "@gast_archive//:gast", @@ -33,7 +33,7 @@ py_test( name = "anf_test", srcs = ["anf_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", tags = ["no_oss"], deps = [ "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/pyct/loader.py b/tensorflow/python/autograph/pyct/loader.py index 6eb925bca45..4fc01e35942 100644 --- a/tensorflow/python/autograph/pyct/loader.py +++ b/tensorflow/python/autograph/pyct/loader.py @@ -31,7 +31,6 @@ import tempfile from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser -from tensorflow.python.autograph.utils import compat_util def _remove_file(file_name): @@ -103,6 +102,3 @@ def load_ast(nodes, # TODO(mdan): Return a structured object. return module, source, source_map - - -compat_util.deprecated_py2_support(__name__) diff --git a/tensorflow/python/autograph/pyct/loader_deprecated_py2.py b/tensorflow/python/autograph/pyct/loader_deprecated_py2.py deleted file mode 100644 index fd962916cac..00000000000 --- a/tensorflow/python/autograph/pyct/loader_deprecated_py2.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Converting AST to code and Python entities. - -Python 2 compatibility version. Not maintained. - -Adapted from Tangent. -""" - -# TODO(mdan): Consolidate with parser and rename to parsing.py - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# TODO(mdan): Use six for compatibility here. -import atexit -import imp -import os -import tempfile - -import six - -from tensorflow.python.autograph.pyct import origin_info -from tensorflow.python.autograph.pyct import parser - - -def load_source(source, delete_on_exit): - """Loads the given source code as a Python module.""" - if six.PY2: - source = source.encode('utf-8') - f = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) - else: - f = tempfile.NamedTemporaryFile( # pylint:disable=unexpected-keyword-arg - mode='w', suffix='.py', delete=False, encoding='utf-8') - - with f: - module_name = os.path.basename(f.name[:-3]) - f.write(source) - - if delete_on_exit: - atexit.register(lambda: os.remove(f.name)) - return imp.load_source(module_name, f.name), f.name - - -def load_ast(nodes, - indentation=' ', - include_source_map=False, - delete_on_exit=True): - """Loads the given AST as a Python module. - - Compiling the AST code this way ensures that the source code is readable by - e.g. `pdb` or `inspect`. - - Args: - nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST - object. - indentation: Text, the string to use for indentation. - include_source_map: bool, whether return a source map. - delete_on_exit: bool, whether to delete the temporary file used for - compilation on exit. - - Returns: - Tuple[module, Text, Dict[LineLocation, OriginInfo]], containing: - the module containing the unparsed nodes, the source code corresponding to - nodes, and the source map. Is include_source_map is False, the source map - will be None. - """ - if not isinstance(nodes, (list, tuple)): - nodes = (nodes,) - - source = parser.unparse(nodes, indentation=indentation) - module, _ = load_source(source, delete_on_exit) - - if include_source_map: - source_map = origin_info.create_source_map(nodes, source, module.__file__) - else: - source_map = None - - # TODO(mdan): Return a structured object. - return module, source, source_map diff --git a/tensorflow/python/autograph/pyct/static_analysis/BUILD b/tensorflow/python/autograph/pyct/static_analysis/BUILD index 1eaf3b3c177..0f05cb58d9e 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/python/autograph/pyct/static_analysis/BUILD @@ -26,7 +26,7 @@ py_library( "reaching_fndefs.py", "type_inference.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ "//tensorflow/python:util", @@ -40,7 +40,7 @@ py_test( name = "activity_test", srcs = ["activity_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":static_analysis", "//tensorflow/python:client_testlib", @@ -53,7 +53,7 @@ py_library( name = "activity_test_lib", testonly = True, srcs = ["activity_test.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":static_analysis", "//tensorflow/python:client_testlib", @@ -89,38 +89,8 @@ py_test( testonly = True, srcs = ["liveness_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", - deps = [ - ":static_analysis", - "//tensorflow/python:client_testlib", - "//tensorflow/python/autograph/pyct", - ], -) - -py_library( - name = "liveness_test_lib", - srcs = ["liveness_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":static_analysis", - "//tensorflow/python:client_testlib", - "//tensorflow/python/autograph/pyct", - "@gast_archive//:gast", - ], -) - -py_test( - name = "liveness_py3_test", - srcs = ["liveness_py3_test.py"], - python_version = "PY3", srcs_version = "PY3", - tags = [ - "no_oss_py2", - "no_pip", - "nopip", - ], deps = [ - ":liveness_test_lib", ":static_analysis", "//tensorflow/python:client_testlib", "//tensorflow/python/autograph/pyct", @@ -131,38 +101,8 @@ py_test( name = "reaching_definitions_test", srcs = ["reaching_definitions_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", - deps = [ - ":static_analysis", - "//tensorflow/python:client_testlib", - "//tensorflow/python/autograph/pyct", - ], -) - -py_library( - name = "reaching_definitions_test_lib", - srcs = ["reaching_definitions_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":static_analysis", - "//tensorflow/python:client_testlib", - "//tensorflow/python/autograph/pyct", - "@gast_archive//:gast", - ], -) - -py_test( - name = "reaching_definitions_py3_test", - srcs = ["reaching_definitions_py3_test.py"], - python_version = "PY3", srcs_version = "PY3", - tags = [ - "no_oss_py2", - "no_pip", - "nopip", - ], deps = [ - ":reaching_definitions_test_lib", ":static_analysis", "//tensorflow/python:client_testlib", "//tensorflow/python/autograph/pyct", diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_py3_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_py3_test.py index 035d3f7458b..6e5e69602fd 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/liveness_py3_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_py3_test.py @@ -30,25 +30,6 @@ NodeAnno = annos.NodeAnno class LivenessAnalyzerTest(liveness_test.LivenessAnalyzerTestBase): """Tests which can only run in Python 3.""" - def test_nonlocal_symbol(self): - - nonlocal_a = 3 - nonlocal_b = 13 - - def test_fn(c): - nonlocal nonlocal_a - nonlocal nonlocal_b - if nonlocal_a: - nonlocal_b = c - else: - nonlocal_b = c - return nonlocal_b - - node = self._parse_and_analyze(test_fn) - fn_body = node.body - self.assertHasLiveOut(fn_body[2], ('nonlocal_b',)) - self.assertHasLiveIn(fn_body[2], ('nonlocal_a', 'c')) - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py index ecb466532e2..1aba627cd3b 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py @@ -552,6 +552,25 @@ class LivenessAnalyzerTest(LivenessAnalyzerTestBase): self.assertHasLiveOut(fn_body[2], ('global_b',)) self.assertHasLiveIn(fn_body[2], ('global_a', 'c')) + def test_nonlocal_symbol(self): + + nonlocal_a = 3 + nonlocal_b = 13 + + def test_fn(c): + nonlocal nonlocal_a + nonlocal nonlocal_b + if nonlocal_a: + nonlocal_b = c + else: + nonlocal_b = c + return nonlocal_b + + node = self._parse_and_analyze(test_fn) + fn_body = node.body + self.assertHasLiveOut(fn_body[2], ('nonlocal_b',)) + self.assertHasLiveIn(fn_body[2], ('nonlocal_a', 'c')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py index ac91b662a47..59feadb45d9 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -465,6 +465,71 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase): self.assertHasDefinedIn(fn_body[2], ('global_a', 'global_b')) + def test_nonlocal(self): + + a = 3 + b = 13 + + def test_fn(): + nonlocal a + nonlocal b + if a: + b = [] + return a, b + + node = self._parse_and_analyze(test_fn) + fn_body = node.body + + self.assertHasDefs(fn_body[2].test, 1) + self.assertHasDefs(fn_body[2].body[0].targets[0], 1) + self.assertHasDefs(fn_body[3].value.elts[0], 1) + self.assertHasDefs(fn_body[3].value.elts[1], 2) + + self.assertSameDef(fn_body[2].test, fn_body[3].value.elts[0]) + + self.assertHasDefinedIn(fn_body[2], ('a', 'b')) + + def test_nonlocal_in_nested_function(self): + + a = 3 + b = 13 + + def test_fn(): + a = 3 + b = 13 + + def local_fn(): + nonlocal a, b + if a: + b = [] + return a, b + + return local_fn() + + node = self._parse_and_analyze(test_fn) + local_body = node.body[2].body + + self.assertHasDefs(local_body[1].test, 1) + self.assertHasDefs(local_body[1].body[0].targets[0], 1) + self.assertHasDefs(local_body[2].value.elts[0], 1) + self.assertHasDefs(local_body[2].value.elts[1], 2) + + self.assertSameDef(local_body[1].test, local_body[2].value.elts[0]) + + # Note: the function name is is visible inside the function body. But it's + # a closure variable, not a local. + # + # Example: + # + # >>> def f(): + # ... print(f) + # >>> g = f + # >>> f = 'something else' + # >>> g() + # something else + # + self.assertHasDefinedIn(local_body[1], ('a', 'b')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/pyct/testing/BUILD b/tensorflow/python/autograph/pyct/testing/BUILD index 59b15ceaf05..811ee6faae1 100644 --- a/tensorflow/python/autograph/pyct/testing/BUILD +++ b/tensorflow/python/autograph/pyct/testing/BUILD @@ -22,7 +22,7 @@ py_library( "basic_definitions.py", "decorators.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//visibility:public"], ) @@ -31,7 +31,7 @@ py_library( srcs = [ "codegen.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ "//tensorflow/python/autograph/pyct", @@ -45,7 +45,7 @@ py_test( size = "large", srcs = ["codegen_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", tags = [ "manual", "no_windows", diff --git a/tensorflow/python/autograph/utils/BUILD b/tensorflow/python/autograph/utils/BUILD index ba44b2b435c..5bb561014c3 100644 --- a/tensorflow/python/autograph/utils/BUILD +++ b/tensorflow/python/autograph/utils/BUILD @@ -29,7 +29,7 @@ py_library( "tensors.py", "testing.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/python:dtypes", @@ -47,7 +47,7 @@ py_test( name = "context_managers_test", srcs = ["context_managers_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":utils", "//tensorflow/python:client_testlib", @@ -58,7 +58,7 @@ py_test( name = "misc_test", srcs = ["misc_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":utils", "//tensorflow/python:client_testlib", @@ -69,7 +69,7 @@ py_test( name = "py_func_test", srcs = ["py_func_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", tags = ["no_windows"], deps = [ ":utils", @@ -81,7 +81,7 @@ py_test( name = "tensor_list_test", srcs = ["tensor_list_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":utils", "//tensorflow/python:client_testlib", @@ -93,7 +93,7 @@ py_test( name = "tensors_test", srcs = ["tensors_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":utils", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc index 55f0debadcb..306381347c7 100644 --- a/tensorflow/python/client/tf_session_wrapper.cc +++ b/tensorflow/python/client/tf_session_wrapper.cc @@ -1172,6 +1172,13 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { return "TensorHandle"; }); + m.def("TF_RegisterFilesystemPlugin", [](const char* plugin_filename) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + TF_RegisterFilesystemPlugin(plugin_filename, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + }); + py::enum_(m, "TF_DataType") .value("TF_FLOAT", TF_FLOAT) .value("TF_DOUBLE", TF_DOUBLE) diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD index ddc8c484390..f7c282750be 100644 --- a/tensorflow/python/compat/BUILD +++ b/tensorflow/python/compat/BUILD @@ -32,7 +32,6 @@ tf_py_test( size = "small", srcs = ["compat_test.py"], tags = ["nofwdcompat"], - tfrt_enabled = True, deps = [ ":compat", "//tensorflow/python:client_testlib", @@ -43,7 +42,6 @@ tf_py_test( name = "disable_v2_behavior_test", size = "small", srcs = ["disable_v2_behavior_test.py"], - tfrt_enabled = True, deps = [ ":v2_compat", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 0c700718285..ad63ab13a58 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 10, 13) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 10, 22) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD index fcbb5c0bd87..6b3f32cadc4 100644 --- a/tensorflow/python/compiler/tensorrt/BUILD +++ b/tensorflow/python/compiler/tensorrt/BUILD @@ -101,7 +101,6 @@ cuda_py_test( "no_windows", "nomac", ], - tfrt_enabled = True, xla_enable_strict_auto_jit = False, deps = [ ":trt_convert_py", @@ -124,6 +123,7 @@ cuda_py_test( cuda_py_tests( name = "tf_trt_integration_test", srcs = [ + "test/annotate_max_batch_sizes_test.py", "test/base_test.py", "test/batch_matmul_test.py", "test/biasadd_matmul_test.py", @@ -182,7 +182,6 @@ cuda_py_test( "no_windows", "nomac", ], - tfrt_enabled = True, xla_enable_strict_auto_jit = False, deps = [ ":tf_trt_integration_test_base", diff --git a/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py b/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py new file mode 100644 index 00000000000..7eadb001708 --- /dev/null +++ b/tensorflow/python/compiler/tensorrt/test/annotate_max_batch_sizes_test.py @@ -0,0 +1,147 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Testing the impact of graph node _tftrt_op_max_batch_size annotation on TRTEngineOp attributes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class MaxBatchSizesTestBase(trt_test.TfTrtIntegrationTestBase): + + @classmethod + def setUpClass(cls): + if cls is MaxBatchSizesTestBase: + raise unittest.SkipTest( + 'MaxBatchSizesTestBase defines base class for other tests.') + super(MaxBatchSizesTestBase, cls).setUpClass() + + @property + def tensor_shapes(self): + return [[1, 512, 1, 1], [64, 2, 2, 2], [32, 4, 2, 2], [16, 8, 2, 2]] + + @property + def max_batch_sizes(self): + return [shape[0] for shape in self.tensor_shapes] + + def GetParams(self): + """Gets the build parameters for the test.""" + return self.BuildParams( + self.GraphFn, + dtype=dtypes.float32, + input_shapes=[self.tensor_shapes[0]], + output_shapes=[self.tensor_shapes[-1]]) + + def ShouldRunTest(self, run_params): + # The maximum batch size for dynamic engines will be the actual batch size + # detected at runtime. Therefore, we don't run the test with dynamic + # engines. + return (not run_params.dynamic_engine, 'test static engine only.') + + def GetConversionParams(self, run_params): + """Returns a ConversionParams for test.""" + conversion_params = super(MaxBatchSizesTestBase, + self).GetConversionParams(run_params) + conversion_params._replace( + max_batch_size=min(self.max_batch_sizes), maximum_cached_engines=1) + rewrite_config_with_trt = self.GetTrtRewriterConfig( + run_params=run_params, + conversion_params=conversion_params, + use_implicit_batch=True, + disable_non_trt_optimizers=True) + return conversion_params._replace( + rewriter_config_template=rewrite_config_with_trt) + + def ExpectedEnginesToBuild(self, run_params): + """Checks that the expected engine is built. + + Args: + run_params: the run parameters. + + Returns: + the expected engines to build. + + There shall be engines generated for each maximum batch size. + """ + return [ + 'TRTEngineOp_{}'.format(seq_id) + for seq_id in range(len(self.max_batch_sizes)) + ] + + def ExpectedMaxBatchSizes(self, run_params): + """Checks that the expected maximum batch sizes for the generated engines. + + Args: + run_params: the run parameters. + + Returns: + the expected maximum batch sizes for the generated engines. + + There shall be engines generated for each maximum batch size. + """ + return self.max_batch_sizes + + +class AnnotateMaxBatchSizesTest(MaxBatchSizesTestBase): + + def GraphFn(self, inp): + """Builds a tf.Graph for the test.""" + tensor = inp * 2.0 + tensor = array_ops.reshape(tensor, [-1] + self.tensor_shapes[1][1:]) + with ops.get_default_graph()._attr_scope({ + '_tftrt_op_max_batch_size': + attr_value_pb2.AttrValue(i=self.max_batch_sizes[1]) + }): + tensor = tensor + 3.0 + tensor = array_ops.reshape(tensor, [-1] + self.tensor_shapes[2][1:]) + with ops.get_default_graph()._attr_scope({ + '_tftrt_op_max_batch_size': + attr_value_pb2.AttrValue(i=self.max_batch_sizes[2]) + }): + tensor = tensor * 4.0 + tensor = array_ops.reshape(tensor, [-1] + self.tensor_shapes[3][1:]) + with ops.get_default_graph()._attr_scope({ + '_tftrt_op_max_batch_size': + attr_value_pb2.AttrValue(i=self.max_batch_sizes[3]) + }): + tensor += tensor + 5.0 + return array_ops.identity(tensor, name='output_0') + + +class StaticBatchSizeTest(MaxBatchSizesTestBase): + + def GraphFn(self, inp): + """Builds a tf.Graph for the test.""" + tensor = inp * 2.0 + tensor = array_ops.reshape(tensor, self.tensor_shapes[1]) + tensor = tensor + 3.0 + tensor = array_ops.reshape(tensor, self.tensor_shapes[2]) + tensor = tensor * 4.0 + tensor = array_ops.reshape(tensor, self.tensor_shapes[3]) + tensor += tensor + 5.0 + return array_ops.identity(tensor, name='output_0') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/compiler/tensorrt/test/batch_matmul_test.py b/tensorflow/python/compiler/tensorrt/test/batch_matmul_test.py index e956ce58814..c32ea99629b 100644 --- a/tensorflow/python/compiler/tensorrt/test/batch_matmul_test.py +++ b/tensorflow/python/compiler/tensorrt/test/batch_matmul_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import unittest + import numpy as np from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test @@ -29,7 +31,30 @@ from tensorflow.python.ops import nn from tensorflow.python.platform import test -class BatchMatMulTwoTensorTest(trt_test.TfTrtIntegrationTestBase): +class BatchMatMultTestBase(trt_test.TfTrtIntegrationTestBase): + """Base class for BatchMatMult tests.""" + + # Shape inference of BatchMatMultV2 doesn't work. Use static batch size. + def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes): + return self.BuildParamsWithMask( + graph_fn=graph_fn, + dtype=dtype, + input_shapes=input_shapes, + output_shapes=output_shapes, + input_mask=[[True] * len(s) for s in input_shapes], + output_mask=[[True] * len(s) for s in output_shapes], + extra_inputs=[], + extra_outputs=[]) + + @classmethod + def setUpClass(cls): + if cls is BatchMatMultTestBase: + raise unittest.SkipTest( + "BatchMatMultTestBase defines base class for other test.") + super(BatchMatMultTestBase, cls).setUpClass() + + +class BatchMatMulTwoTensorTest(BatchMatMultTestBase): """Testing conversion of BatchMatMul where both inputs are tensors.""" def GraphFn(self, inp, inp1): @@ -47,7 +72,7 @@ class BatchMatMulTwoTensorTest(trt_test.TfTrtIntegrationTestBase): return {"TRTEngineOp_0": ["matmul", "relu"]} -class BatchMatMulWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): +class BatchMatMulWeightBroadcastTest(BatchMatMultTestBase): """Testing BatchMatMulV2: one operand is weight and both have same rank.""" def GraphFn(self, inp): @@ -66,7 +91,7 @@ class BatchMatMulWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): return {"TRTEngineOp_0": ["matmul", "kernel"]} -class BatchMatMulWeightBroadcastDims2Test(trt_test.TfTrtIntegrationTestBase): +class BatchMatMulWeightBroadcastDims2Test(BatchMatMultTestBase): """Testing BatchMatMulV2: weight operand must be broadcasted.""" def GraphFn(self, inp): diff --git a/tensorflow/python/compiler/tensorrt/test/rank_two_test.py b/tensorflow/python/compiler/tensorrt/test/rank_two_test.py index b23e052a316..8884cd702d2 100644 --- a/tensorflow/python/compiler/tensorrt/test/rank_two_test.py +++ b/tensorflow/python/compiler/tensorrt/test/rank_two_test.py @@ -67,9 +67,14 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): "abs0_2", "expand0_0", "expand0_1", "axis" ], "TRTEngineOp_1": [ - "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3", - "abs1_1", "abs1_2", "reciprocal0", "reciprocal1" + "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3", "abs1_1", + "abs1_2", "reciprocal1" ], + # The two ops can't be in the same cluster as the ops in TRTEngineOp_0 + # due to trt_incompatible_op. They can't be in the same cluster as the + # ops in TRTEngineOP_1 because their batch size belongs to a different + # equivalent class. + "TRTEngineOp_2": ["add", "reciprocal0"] } diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 76aa755391f..2265a19cf62 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from collections import namedtuple +import collections import errno import gc import itertools @@ -57,7 +57,7 @@ from tensorflow.python.tools import saved_model_utils from tensorflow.python.training.tracking import tracking from tensorflow.python.util import nest -TfTrtIntegrationTestParams = namedtuple( +TfTrtIntegrationTestParams = collections.namedtuple( "TfTrtIntegrationTestParams", [ # A function that creates the TF graph for testing. @@ -74,7 +74,7 @@ TfTrtIntegrationTestParams = namedtuple( "expected_output_dims" ]) -RunParams = namedtuple( +RunParams = collections.namedtuple( "RunParams", [ # Whether to run the conversion online with RewriterConfig, or offline @@ -305,9 +305,13 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): run_params.precision_mode)), "test either calibration or non-INT8" def ExpectedEnginesToBuild(self, run_params): - """Return the expected engines to build, implemented by subclass.""" + """Returns the expected engines to build, implemented by subclass.""" raise NotImplementedError() + def ExpectedMaxBatchSizes(self, run_params): + """Returns the expected maximum batch sizes of the build engines.""" + return None + def ExpectedAbsoluteTolerance(self, run_params): """The absolute tolerance to compare floating point results.""" return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02 @@ -537,18 +541,39 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): logging.info("Writing graph to %s/%s", temp_dir, graph_name) graph_io.write_graph(gdef, temp_dir, graph_name) - # Remove the graph sequence number prefix from the name only if the name has - # a prefix TRTEngineOp_n_. When expecting_prefix is true, assert such a - # prefix exists. - def _RemoveGraphSequenceNumberImpl(self, name, expecting_prefix): - match = re.search(r"TRTEngineOp_\d+_", name) - has_prefix = match and name.startswith(match.group(0)) - assert (not expecting_prefix) or has_prefix - if has_prefix: - parts = name.split("_", maxsplit=2) - assert len(parts) == 3 - return parts[0] + "_" + parts[2] - return name + # Removes the prefix(s) of function name(s). + # The input value can be a string or a sequence of string. + def _Canonicalize(self, value): + if isinstance(value, str): + return self._ToString(value.split("/")[-1]) + elif isinstance(value, collections.abc.Iterable): + return set(self._Canonicalize(nm) for nm in value) + else: + raise TypeError( + "'_Canonicalize' can only be used on strings or sequence of strings!") + + # Removes the graph sequence number prefix from the name(s) only if the + # name(s) has a prefix TRTEngineOp_n_. When expecting_prefix is true, asserts + # such a prefix exists. + # The input value can be a string or a sequence of string. + def _RemoveGraphSequenceNumberImpl(self, value, expecting_prefix): + if isinstance(value, str): + match = re.search(r"TRTEngineOp_\d+_", value) + has_prefix = match and value.startswith(match.group(0)) + assert (not expecting_prefix) or has_prefix + if has_prefix: + parts = value.split("_", maxsplit=2) + assert len(parts) == 3 + return parts[0] + "_" + parts[2] + return value + elif isinstance(value, collections.abc.Iterable): + return set( + self._RemoveGraphSequenceNumberImpl(nm, expecting_prefix) + for nm in value) + else: + raise TypeError( + "'_RemoveGraphSequenceNumberImpl' can only be used on strings " + "or sequence of strings!") def _RemoveGraphSequenceNumber(self, name): return self._RemoveGraphSequenceNumberImpl(name, True) @@ -644,6 +669,124 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): msg="\nexpected:\n%s\nvs actual:\n%s" % (sorted(expected_input_map.items()), sorted(actual_input_map.items()))) + def _VerifyMaxBatchSizeAnnotations( + self, + expected_engines, + original_gdef, + converted_gdef, + default_max_batch_size, + expected_max_batch_sizes=None, + ): + """Verifies the max batch size annotations in the original and converted GraphDef. + + Args: + expected_engines: A sequence of engines names. + original_gdef: GraphDef. The graph def before TensorRT conversion. + converted_gdef: GraphDef. The graph def after TensorRT conversion. + default_max_batch_size: The default maximum batch size to use if no node + inside a segment is annoted with a customized max batch size. + expected_max_batch_sizes: Optional. A sequence of max batch sizes for all + the engines. `None` if does not check enforce max batch sizes. + """ + if isinstance(expected_max_batch_sizes, collections.abc.Collection): + self.assertEqual(len(expected_max_batch_sizes), len(expected_engines)) + else: + self.assertIsNone( + expected_max_batch_sizes, + "'expected_max_batch_sizes' shall only be a sequence " + "of integers or `None`.") + + def _ChainAllNodes(graph_def): + return itertools.chain( + graph_def.node, + itertools.chain( + *[func.node_def for func in graph_def.library.function])) + + old_name_to_node_map = { + self._ToString(node.name): node + for node in _ChainAllNodes(original_gdef) + } + new_name_to_func_map = { + self._ToString(func.signature.name): func + for func in converted_gdef.library.function + } + + def _DetectStaticBatchSize(node_def): + """Returns the static batch size of an operation or None. + + It is incorrect to use the output shapes to find the batch size of an + operation, as the segmenter actually uses the input shapes. However, it is + a simplication and works for most of the cases for the test purposes. + + Args: + node_def: `tf.NodeDef`. The target node for analysis. + + Returns: + If all the outputs of the node have the same static batch size, returns + the int value for the batch size. Otherwise returns None. + """ + shapes = node_def.attr["_output_shapes"].list.shape + batch_size = set( + list(s.dim)[0].size if len(s.dim) >= 2 else None for s in shapes) + if len(batch_size) == 1 and list(batch_size)[0] >= 1: + return list(batch_size)[0] + return None + + name_to_engines_map = {} + actual_max_batch_sizes = [] + for node in _ChainAllNodes(converted_gdef): + if node.op == "TRTEngineOp": + engine = node + engine_name = self._RemoveGraphSequenceNumber( + self._Canonicalize(self._ToString(engine.name))) + self.assertIn(engine_name, expected_engines) + name_to_engines_map[engine_name] = engine + # The input nodes shall not have the conflicting annotation (no + # annotation or the same annotation) with the maximum batch size + # annotation. If the engine has maximum batch size annotation as the + # non-default maximum batch size, then at least one input node shall + # have the same annotation to be the source. + self.assertIn("max_batch_size", node.attr) + engine_max_batch_size = node.attr["max_batch_size"].i + self.assertIsInstance(engine_max_batch_size, int) + actual_max_batch_sizes.append(engine_max_batch_size) + seg_func = node.attr["segment_func"].func + self.assertIsNotNone(seg_func) + self.assertIn(seg_func.name, new_name_to_func_map) + seg_func_def = new_name_to_func_map[seg_func.name] + logging.info("Segment function name: %s. Including %d nodes.", + seg_func.name, len(seg_func_def.node_def)) + node_max_batch_size_all_none = True + # Use the native segment to search for replaced nodes + for alternative_node in seg_func_def.node_def: + node_name = self._Canonicalize(self._ToString(alternative_node.name)) + if node_name not in old_name_to_node_map: + continue + original_node = old_name_to_node_map[node_name] + node_max_batch_size = None + if "_tftrt_op_max_batch_size" in original_node.attr: + node_max_batch_size = original_node.attr[ + "_tftrt_op_max_batch_size"].i + elif (original_node.op != "Const" and + alternative_node.op != "Const" and + "_output_shapes" in original_node.attr): + node_max_batch_size = _DetectStaticBatchSize(original_node) + logging.info( + "'{%s}(%s)'s max batch size annotation is %s. " + "'{%s}'s max batch size is %s.", node_name, original_node.op, + str(node_max_batch_size), engine_name, str(engine_max_batch_size)) + node_max_batch_size_all_none &= node_max_batch_size is None + self.assertTrue(engine_max_batch_size == node_max_batch_size or + node_max_batch_size is None) + logging.info("'{%s}'s max batch size is %d.", engine_name, + engine_max_batch_size) + self.assertTrue(engine_max_batch_size == default_max_batch_size or + not node_max_batch_size_all_none) + + self.assertCountEqual(expected_engines, tuple(name_to_engines_map.keys())) + if expected_max_batch_sizes is not None: + self.assertCountEqual(expected_max_batch_sizes, actual_max_batch_sizes) + def _GetGraphDef(self, run_params, gdef_or_saved_model_dir): if isinstance(gdef_or_saved_model_dir, str): if run_params.is_v2: @@ -703,7 +846,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): self.assertEqual(num_engines, len(expected_engines)) if isinstance(expected_engines, dict): self._VerifyConnections(expected_engines, original_gdef, gdef_to_verify) - # TODO(aaroey): consider verifying the corresponding TF function. + self._VerifyMaxBatchSizeAnnotations( + expected_engines=expected_engines, + original_gdef=original_gdef, + converted_gdef=gdef_to_verify, + expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params), + default_max_batch_size=self.GetConversionParams( + run_params).max_batch_size, + ) def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify, graph_state): @@ -721,15 +871,10 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): all_op_names.append(node.name) if node.op == "TRTEngineOp": trt_op_names.append(node.name) - # Remove the function name prefix. - def _Canonicalize(names): - return set(self._ToString(name.split("/")[-1]) for name in names) - # Remove the graph sequence number prefix from all the names. - def _RemoveGraphSequenceNumber(names): - return set(self._RemoveGraphSequenceNumber(name) for name in names) - all_op_names = _Canonicalize(all_op_names) - trt_op_names = _RemoveGraphSequenceNumber(_Canonicalize(trt_op_names)) + all_op_names = self._Canonicalize(all_op_names) + trt_op_names = self._RemoveGraphSequenceNumber( + self._Canonicalize(trt_op_names)) if isinstance(expected_engines, dict): # For simplicity we don't verify the connections inside the engine in @@ -741,6 +886,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): expected_engines = set(expected_engines.keys()) self.assertEqual(set(expected_engines), trt_op_names) + self._VerifyMaxBatchSizeAnnotations( + expected_engines=expected_engines, + original_gdef=original_gdef, + converted_gdef=gdef_to_verify, + expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params), + default_max_batch_size=self.GetConversionParams( + run_params).max_batch_size, + ) def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir, gdef_or_saved_model_dir_to_verify, graph_state): diff --git a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py index ee56be3f11a..5c7f261fa98 100644 --- a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py +++ b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py @@ -125,7 +125,8 @@ class ExplicitBatchTest(TrtModeTestBase): def GetConversionParams(self, run_params): """Return a TrtConversionParams for test that enables explicit batch.""" - return super(ExplicitBatchTest, self).GetConversionParams(run_params, False) + return super(ExplicitBatchTest, self).GetConversionParams( + run_params, implicit_batch=False) def ExpectedEnginesToBuild(self, run_params): """Check that the expected engine is built. diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index 29cb6d14856..8ea5c96f4cc 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -50,6 +50,7 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import saver from tensorflow.python.training.tracking import tracking +from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.tf_export import tf_export @@ -426,6 +427,8 @@ class TrtGraphConverter(object): ``` """ + @deprecation.deprecated_args(None, "Remove the use of this argument", + "session_config") def __init__(self, input_saved_model_dir=None, input_saved_model_tags=None, @@ -994,6 +997,9 @@ class TrtGraphConverterV2(object): assert context.executing_eagerly() if conversion_params is None: conversion_params = TrtConversionParams() + elif conversion_params.rewriter_config_template is not None: + tf_logging.warn("the rewrite_config_template field will be deprecated.") + _check_trt_version_compatibility() _check_conversion_params(conversion_params, is_v2=True) @@ -1052,6 +1058,10 @@ class TrtGraphConverterV2(object): [tensor.name for tensor in func.outputs]) rebuilt_func.graph.structured_outputs = nest.pack_sequence_as( func.graph.structured_outputs, rebuilt_func.graph.structured_outputs) + # Copy structured input signature from original function (used during + # serialization) + rebuilt_func.graph.structured_input_signature = ( + func.structured_input_signature) return rebuilt_func # TODO(laigd): provide a utility function to optimize a ConcreteFunction and @@ -1104,6 +1114,10 @@ class TrtGraphConverterV2(object): self._converted_func.graph.structured_outputs = nest.pack_sequence_as( func.graph.structured_outputs, self._converted_func.graph.structured_outputs) + # Copy structured input signature from original function (used during + # serialization) + self._converted_func.graph.structured_input_signature = ( + func.structured_input_signature) if self._need_calibration: for inp in calibration_input_fn(): @@ -1259,6 +1273,8 @@ class TrtGraphConverterV2(object): reset_converted_func.graph.structured_outputs = nest.pack_sequence_as( self._converted_func.graph.structured_outputs, reset_converted_func.graph.structured_outputs) + reset_converted_func.graph.strucutred_input_signature = ( + self._converted_func.structured_input_signature) self._converted_func = reset_converted_func signatures[self._input_saved_model_signature_key] = self._converted_func diff --git a/tensorflow/python/compiler/xla/BUILD b/tensorflow/python/compiler/xla/BUILD index 5d92e43c61a..79c18571f9a 100644 --- a/tensorflow/python/compiler/xla/BUILD +++ b/tensorflow/python/compiler/xla/BUILD @@ -24,6 +24,9 @@ cuda_py_test( name = "jit_test", size = "small", srcs = ["jit_test.py"], + tags = [ + "no_windows", # TODO(b/171385770) + ], xla_enabled = True, deps = [ ":compiler_py", @@ -99,7 +102,6 @@ cuda_py_test( "no_mac", "no_windows", ], - tfrt_enabled = True, xla_enabled = True, deps = [ "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/data/benchmarks/BUILD b/tensorflow/python/data/benchmarks/BUILD index 94c189c43f1..3f0faf5364a 100644 --- a/tensorflow/python/data/benchmarks/BUILD +++ b/tensorflow/python/data/benchmarks/BUILD @@ -10,7 +10,6 @@ exports_files(["LICENSE"]) tf_py_test( name = "meta_benchmark", srcs = ["meta_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:session", @@ -35,7 +34,6 @@ py_library( tf_py_test( name = "batch_benchmark", srcs = ["batch_benchmark.py"], - tfrt_enabled = True, deps = [ ":benchmark_base", "//tensorflow/python:sparse_tensor", @@ -47,7 +45,6 @@ tf_py_test( tf_py_test( name = "filter_benchmark", srcs = ["filter_benchmark.py"], - tfrt_enabled = True, deps = [ ":benchmark_base", "//tensorflow/python/data/ops:dataset_ops", @@ -57,7 +54,6 @@ tf_py_test( tf_py_test( name = "from_tensor_slices_benchmark", srcs = ["from_tensor_slices_benchmark.py"], - tfrt_enabled = True, deps = [ ":benchmark_base", "//tensorflow/python/data/experimental/ops:get_single_element", @@ -69,7 +65,6 @@ tf_py_test( tf_py_test( name = "list_files_benchmark", srcs = ["list_files_benchmark.py"], - tfrt_enabled = True, deps = [ ":benchmark_base", "//tensorflow/python:client_testlib", @@ -84,7 +79,6 @@ tf_py_test( tf_py_test( name = "map_benchmark", srcs = ["map_benchmark.py"], - tfrt_enabled = True, deps = [ ":benchmark_base", "//tensorflow/python/data/ops:dataset_ops", @@ -94,7 +88,6 @@ tf_py_test( tf_py_test( name = "prefetch_benchmark", srcs = ["prefetch_benchmark.py"], - tfrt_enabled = True, deps = [ ":benchmark_base", "//tensorflow/python/data/ops:dataset_ops", @@ -104,7 +97,6 @@ tf_py_test( tf_py_test( name = "range_benchmark", srcs = ["range_benchmark.py"], - tfrt_enabled = True, deps = [ ":benchmark_base", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/python/data/experimental/benchmarks/BUILD b/tensorflow/python/data/experimental/benchmarks/BUILD index a3ceb9ed37f..e3ca2d52ab5 100644 --- a/tensorflow/python/data/experimental/benchmarks/BUILD +++ b/tensorflow/python/data/experimental/benchmarks/BUILD @@ -25,7 +25,6 @@ py_binary( tf_py_test( name = "autotune_benchmark", srcs = ["autotune_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:math_ops", @@ -38,7 +37,6 @@ tf_py_test( tf_py_test( name = "choose_fastest_benchmark", srcs = ["choose_fastest_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -52,7 +50,6 @@ tf_py_test( tf_py_test( name = "choose_fastest_branch_benchmark", srcs = ["choose_fastest_branch_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -69,7 +66,6 @@ tf_py_test( name = "csv_dataset_benchmark", srcs = ["csv_dataset_benchmark.py"], tags = ["no_pip"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:parsing_ops", @@ -85,7 +81,6 @@ tf_py_test( tf_py_test( name = "map_and_batch_benchmark", srcs = ["map_and_batch_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -104,7 +99,6 @@ tf_py_test( tf_py_test( name = "map_defun_benchmark", srcs = ["map_defun_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -120,7 +114,6 @@ tf_py_test( tf_py_test( name = "map_vectorization_benchmark", srcs = ["map_vectorization_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -140,7 +133,6 @@ tf_py_test( name = "matching_files_benchmark", size = "small", srcs = ["matching_files_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -156,7 +148,6 @@ tf_py_test( tf_py_test( name = "optimize_benchmark", srcs = ["optimize_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -170,7 +161,6 @@ tf_py_test( tf_py_test( name = "parallel_interleave_benchmark", srcs = ["parallel_interleave_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:math_ops", @@ -186,7 +176,6 @@ tf_py_test( name = "rejection_resample_benchmark", srcs = ["rejection_resample_benchmark.py"], tags = ["no_pip"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/data/experimental/ops:resampling", @@ -199,7 +188,6 @@ tf_py_test( tf_py_test( name = "snapshot_dataset_benchmark", srcs = ["snapshot_dataset_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -218,7 +206,6 @@ tf_py_test( tf_py_test( name = "unbatch_benchmark", srcs = ["unbatch_benchmark.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index 0ea339d6301..b1510a56c7e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -157,7 +157,6 @@ tf_py_test( srcs = ["data_service_ops_test.py"], shard_count = 10, srcs_version = "PY3", - tags = ["notap"], # TODO(b/170783829) deps = [ ":data_service_test_base", "//tensorflow:tensorflow_py", @@ -380,9 +379,6 @@ tf_py_test( name = "map_and_batch_test", size = "medium", srcs = ["map_and_batch_test.py"], - tags = [ - "nomsan", # b/168906619 - ], deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -406,7 +402,6 @@ tf_py_test( size = "small", srcs = ["map_defun_op_test.py"], tags = ["no_pip"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", @@ -744,7 +739,6 @@ tf_py_test( tf_py_test( name = "sleep_test", srcs = ["sleep_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:util", @@ -811,6 +805,7 @@ tf_py_test( size = "small", srcs = ["stats_dataset_ops_test.py"], tags = [ + "no_oss", # TODO(b/155795733): Note that this functionality is deprecated. "no_pip", "notap", ], diff --git a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py index 564dda0cf11..d428baca9c0 100644 --- a/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/auto_shard_dataset_test.py @@ -103,6 +103,43 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, ] self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(batch_size=[1, 3, 10]))) + def testDatasetOfReaderDatasetsPipeline(self, batch_size): + # This tests a scenario where a list_files main return multiple files + # due to the glob containing wildcards. + def batch(iterator, n): + l = len(iterator) + for i in range(0, l, n): + yield iterator[i:min(i + n, l)] + + datasets = [] + for files in batch(self.test_filenames, batch_size): + datasets.append( + dataset_ops.Dataset.list_files(files, shuffle=False).map( + core_readers.TFRecordDataset)) + dataset = dataset_ops.Dataset.from_tensor_slices(datasets) + dataset = dataset.flat_map(lambda x: x) + + # Simulate additional ops in between flat_map and interleave. This should be + # a no-op since if ShardDataset is placed right after flat_map, we will only + # have two datasets left at this point. + dataset = dataset.prefetch(1) + dataset = dataset.prefetch(1) + + dataset = dataset.interleave( + lambda x: x, cycle_length=1, num_parallel_calls=1) + + dataset = distribute._AutoShardDataset(dataset, 5, 0) + expected = [ + b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension + for f in (0, 5) + for r in range(0, 10) + ] + + self.assertDatasetProduces(dataset, expected) + @combinations.generate(test_base.default_test_combinations()) def testZipReaderPipeline(self): dataset1 = dataset_ops.Dataset.list_files( diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_ft_test.py b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_ft_test.py index 99d1a7b563e..33b7ca25985 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_ft_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_ft_test.py @@ -219,6 +219,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testChangeProcessingModeAfterRestart(self): + self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=1) num_elements = 100 range_dataset = dataset_ops.Dataset.range(num_elements) diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py index 65674818daf..ddd301d1540 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py @@ -214,7 +214,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testSharedJobName(self): cluster = self.create_cluster(num_workers=1) - num_elements = 100 + num_elements = 1000 def make_ds(): return dataset_ops.Dataset.range(num_elements).shuffle(num_elements) diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py b/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py index 0bb1383a56b..0e48e1f4dd9 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py @@ -99,7 +99,8 @@ class TestCluster(object): server_lib.WorkerServer( server_lib.WorkerConfig( dispatcher_address=self.dispatcher_address(), - heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS), + heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS, + dispatcher_timeout_ms=1000), start=start)) def start_dispatcher(self): diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD index 46ae19a05a9..e3d8c60d317 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") package( default_visibility = ["//tensorflow:internal"], @@ -255,6 +255,7 @@ tf_py_test( "no_oss", "no_pip", "no_windows", + "noasan", # TODO(b/337374867) fails with -fsanitize=null ], deps = [ "//tensorflow/core:protos_all_py", diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py index 8175480182f..941ce327555 100644 --- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py @@ -287,6 +287,40 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output = [[0], [1], [2], [3], [], [4], [5], [6], [7], []] self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(drop_remainder=[True, False]))) + def testEmptyFirstSplits(self, drop_remainder): + dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True) + rebatched_dataset = distribute._RebatchDataset( + dataset, batch_sizes=[0, 1], drop_remainder=drop_remainder) + + expected_shapes = [[None]] + self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset)) + + # We have an extra element at the end because if the desired batch size is + # zero, then we never read any inputs from the input_dataset at all, so we + # will keep producting empty outputs until we reach a non zero desired batch + # size split. + expected_output = [[], [0], [], [1], [], [2], [], [3], + [], [4], [], [5], [], [6], [], [7], []] + self.assertDatasetProduces(rebatched_dataset, expected_output) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(drop_remainder=[True, False]))) + def testEmptyLastSplits(self, drop_remainder): + dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True) + rebatched_dataset = distribute._RebatchDataset( + dataset, batch_sizes=[1, 0], drop_remainder=drop_remainder) + + expected_shapes = [[None]] + self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset)) + + expected_output = [[0], [], [1], [], [2], [], [3], [], + [4], [], [5], [], [6], [], [7], []] + self.assertDatasetProduces(rebatched_dataset, expected_output) + @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(drop_remainder=[True, False]))) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD index da710783151..9aa9f3f5447 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD @@ -40,7 +40,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -59,7 +58,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:array_ops", @@ -82,7 +80,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:array_ops", @@ -103,7 +100,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -122,7 +118,6 @@ tf_py_test( "no_windows", "notsan", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -166,7 +161,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -184,7 +178,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -203,7 +196,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -222,7 +214,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -240,7 +231,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -259,7 +249,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -280,7 +269,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -298,7 +286,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -324,7 +311,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -343,7 +329,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -361,8 +346,8 @@ tf_py_test( "no_oss", "no_pip", "no_windows", + "noasan", # TODO(b/337374867) fails with -fsanitize=null ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:array_ops", @@ -382,7 +367,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -403,7 +387,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -422,7 +405,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -446,7 +428,6 @@ tf_py_test( tags = [ "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -465,7 +446,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -483,7 +463,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -501,7 +480,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:array_ops", @@ -521,7 +499,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -542,7 +519,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -567,7 +543,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -584,7 +559,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -601,7 +575,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -625,7 +598,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -643,7 +615,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -662,7 +633,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -680,7 +650,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_combinations", @@ -701,7 +670,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -719,7 +687,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -737,7 +704,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -757,7 +723,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:array_ops", @@ -777,7 +742,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:array_ops", @@ -798,7 +762,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -818,7 +781,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -837,7 +799,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -855,7 +816,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -874,7 +834,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", @@ -892,7 +851,6 @@ tf_py_test( "no_pip", "no_windows", ], - tfrt_enabled = True, deps = [ ":dataset_serialization_test_base", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index c46aeb116c3..b1fa780f6b3 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -325,10 +325,12 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, dataset = dataset_ops.Dataset.zip((dataset1, dataset2, dataset3, dataset4)) dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir)) - next1 = self.getNext(dataset) - for i in range(0, 1000): - self.assertEqual((i, i + 1000, i + 2000, i + 3000), - self.evaluate(next1())) + + expected = list( + zip( + range(0, 1000), range(1000, 2000), range(2000, 3000), + range(3000, 4000))) + self.assertDatasetProduces(dataset, expected) self.assertSnapshotDirectoryContains( self._snapshot_dir, num_fingerprints=1, diff --git a/tensorflow/python/data/experimental/service/BUILD b/tensorflow/python/data/experimental/service/BUILD index b6096b4b0bb..020e5a91e29 100644 --- a/tensorflow/python/data/experimental/service/BUILD +++ b/tensorflow/python/data/experimental/service/BUILD @@ -36,7 +36,6 @@ py_library( tf_py_test( name = "server_lib_test", srcs = ["server_lib_test.py"], - tfrt_enabled = True, deps = [ ":server_lib", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/python/data/experimental/service/server_lib.py b/tensorflow/python/data/experimental/service/server_lib.py index 95179a4a7df..addd20fb73b 100644 --- a/tensorflow/python/data/experimental/service/server_lib.py +++ b/tensorflow/python/data/experimental/service/server_lib.py @@ -217,7 +217,7 @@ class DispatchServer(object): class WorkerConfig( collections.namedtuple("WorkerConfig", [ "dispatcher_address", "worker_address", "port", "protocol", - "heartbeat_interval_ms" + "heartbeat_interval_ms", "dispatcher_timeout_ms" ])): """Configuration class for tf.data service dispatchers. @@ -235,6 +235,8 @@ class WorkerConfig( reasonable default. A higher value will reduce the load on the dispatcher, while a lower value will reduce the time it takes to reclaim resources from finished jobs. + dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the + dispatcher before giving up and reporting an error. Defaults to 1 hour. """ def __new__(cls, @@ -242,15 +244,19 @@ class WorkerConfig( worker_address=None, port=0, protocol="grpc", - heartbeat_interval_ms=None): + heartbeat_interval_ms=None, + dispatcher_timeout_ms=None): if worker_address is None: worker_address = "localhost:%port%" if heartbeat_interval_ms is None: heartbeat_interval_ms = 30 * 1000 # 30 seconds + if dispatcher_timeout_ms is None: + dispatcher_timeout_ms = 60 * 60 * 1000 # 1 hour return super(WorkerConfig, cls).__new__(cls, dispatcher_address, worker_address, port, - protocol, heartbeat_interval_ms) + protocol, heartbeat_interval_ms, + dispatcher_timeout_ms) @tf_export("data.experimental.service.WorkerServer", v1=[]) @@ -299,7 +305,8 @@ class WorkerServer(object): worker_address=config.worker_address, port=config.port, protocol=config.protocol, - heartbeat_interval_ms=config.heartbeat_interval_ms) + heartbeat_interval_ms=config.heartbeat_interval_ms, + dispatcher_timeout_ms=config.dispatcher_timeout_ms) self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( config_proto.SerializeToString()) if start: diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 2705778fba6..b3e07df10b1 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -114,7 +114,6 @@ tf_py_test( name = "dataset_spec_test", size = "small", srcs = ["dataset_spec_test.py"], - tfrt_enabled = True, deps = [ ":test_base", "//tensorflow/python:client_testlib", @@ -226,7 +225,6 @@ tf_py_test( name = "from_sparse_tensor_slices_test", size = "small", srcs = ["from_sparse_tensor_slices_test.py"], - tfrt_enabled = True, deps = [ ":test_base", "//tensorflow/core:protos_all_py", @@ -318,7 +316,6 @@ tf_py_test( "no_oss", # Test flaky due to port collisions. "no_windows", ], - tfrt_enabled = True, deps = [ ":test_base", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index a7d2ed4840f..8ebc5f140a3 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -833,14 +833,20 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable, output_signature = nest.map_structure_up_to(output_types, tensor_spec.TensorSpec, output_shapes, output_types) + if all([ + isinstance(x, tensor_spec.TensorSpec) + for x in nest.flatten(output_signature) + ]): + output_types = nest.pack_sequence_as( + output_signature, [x.dtype for x in nest.flatten(output_signature)]) + output_shapes = nest.pack_sequence_as( + output_signature, [x.shape for x in nest.flatten(output_signature)]) if args is None: args = () else: args = tuple(ops.convert_n_to_tensor(args, name="args")) - flat_output_types = structure.get_flat_tensor_types(output_signature) - generator_state = DatasetV2._GeneratorState(generator) def get_iterator_id_fn(unused_dummy): @@ -872,38 +878,112 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable, Returns: The next element to generate from the iterator. """ + if output_types and output_shapes: + flattened_types = [ + dtypes.as_dtype(dt) for dt in nest.flatten(output_types) + ] + flattened_shapes = nest.flatten(output_shapes) - def generator_py_func(iterator_id): - """A `py_func` that will be called to invoke the iterator.""" - # `next()` raises `StopIteration` when there are no more - # elements remaining to be generated. - values = next(generator_state.get_iterator(iterator_id.numpy())) + def generator_py_func(iterator_id): + """A `py_func` that will be called to invoke the iterator.""" + # `next()` raises `StopIteration` when there are no more + # elements remaining to be generated. + values = next(generator_state.get_iterator(iterator_id)) - try: - values = structure.normalize_element(values, output_signature) - except (TypeError, ValueError): - six.reraise( - TypeError, - TypeError( - "`generator` yielded an element that did not match the " - "expected structure. The expected structure was %s, but the " - "yielded element was %s." % (output_signature, values)), - sys.exc_info()[2]) + # Use the same _convert function from the py_func() implementation to + # convert the returned values to arrays early, so that we can inspect + # their values. + try: + flattened_values = nest.flatten_up_to(output_types, values) + except (TypeError, ValueError): + six.reraise( + TypeError, + TypeError( + "`generator` yielded an element that did not match the " + "expected structure. The expected structure was %s, but " + "the yielded element was %s." % (output_types, values)), + sys.exc_info()[2]) + ret_arrays = [] + for ret, dtype in zip(flattened_values, flattened_types): + try: + ret_arrays.append( + script_ops.FuncRegistry._convert( # pylint: disable=protected-access + ret, + dtype=dtype.as_numpy_dtype)) + except (TypeError, ValueError): + six.reraise( + TypeError, + TypeError( + "`generator` yielded an element that could not be " + "converted to the expected type. The expected type was " + "%s, but the yielded element was %s." % + (dtype.name, ret)), + sys.exc_info()[2]) - values_spec = structure.type_spec_from_value(values) + # Additional type and shape checking to ensure that the components of + # the generated element match the `output_types` and `output_shapes` + # arguments. + for (ret_array, expected_dtype, + expected_shape) in zip(ret_arrays, flattened_types, + flattened_shapes): + if ret_array.dtype != expected_dtype.as_numpy_dtype: + raise TypeError( + "`generator` yielded an element of type %s where an element " + "of type %s was expected." % + (ret_array.dtype, expected_dtype.as_numpy_dtype)) + if not expected_shape.is_compatible_with(ret_array.shape): + raise ValueError( + "`generator` yielded an element of shape %s where an element " + "of shape %s was expected." % + (ret_array.shape, expected_shape)) - if not structure.are_compatible(values_spec, output_signature): - raise TypeError( - "`generator` yielded an element of %s where an element " - "of %s was expected." % (values_spec, output_signature)) + return ret_arrays - return structure.to_tensor_list(output_signature, values) + flat_values = script_ops.numpy_function(generator_py_func, + [iterator_id_t], + flattened_types) - return script_ops._eager_py_func( # pylint: disable=protected-access - generator_py_func, - inp=[iterator_id_t], - Tout=flat_output_types, - use_tape_cache=False) + # The `py_func()` op drops the inferred shapes, so we add them back in + # here. + if output_shapes is not None: + for ret_t, shape in zip(flat_values, flattened_shapes): + ret_t.set_shape(shape) + + return nest.pack_sequence_as(output_types, flat_values) + else: + flat_output_types = structure.get_flat_tensor_types(output_signature) + + def generator_py_func(iterator_id): + """A `py_func` that will be called to invoke the iterator.""" + # `next()` raises `StopIteration` when there are no more + # elements remaining to be generated. + values = next(generator_state.get_iterator(iterator_id.numpy())) + + try: + values = structure.normalize_element(values, output_signature) + except (TypeError, ValueError): + six.reraise( + TypeError, + TypeError( + "`generator` yielded an element that did not match the " + "expected structure. The expected structure was %s, but " + "the yielded element was %s." % (output_signature, values)), + sys.exc_info()[2]) + + values_spec = structure.type_spec_from_value(values) + + if not structure.are_compatible(values_spec, output_signature): + raise TypeError( + "`generator` yielded an element of %s where an element " + "of %s was expected." % (values_spec, output_signature)) + + return structure.to_tensor_list(output_signature, values) + + return script_ops._eager_py_func( # pylint: disable=protected-access + generator_py_func, + inp=[iterator_id_t], + Tout=flat_output_types, + use_tape_cache=False) def finalize_fn(iterator_id_t): """Releases host-side state for the iterator with ID `iterator_id_t`.""" @@ -1990,7 +2070,7 @@ name=None)) stride of the input elements in the sliding window. Must be positive. The default value of 1 means "retain every input element". drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing - whether the last window should be dropped if its size is smaller than + whether the last windows should be dropped if their size is smaller than `size`. Returns: diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 176d2546525..36fdb20aeae 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -797,7 +797,6 @@ cuda_py_test( size = "small", srcs = ["lib/debug_gradients_test.py"], python_version = "PY3", - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Node names are different with autojit deps = [ ":debug_data", @@ -1042,7 +1041,6 @@ cuda_py_test( size = "small", srcs = ["lib/debug_grappler_test.py"], python_version = "PY3", - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Tests TF:Classic implementation. deps = [ ":debug_data", @@ -1061,7 +1059,6 @@ cuda_py_test( srcs = ["lib/session_debug_file_test.py"], python_version = "PY3", tags = ["notsan"], - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Node names are different with autojit deps = [ ":debug_data", @@ -1080,7 +1077,6 @@ cuda_py_test( size = "small", srcs = ["lib/debug_graph_reconstruction_test.py"], python_version = "PY3", - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Node names are different with autojit deps = [ ":debug_data", @@ -1101,7 +1097,6 @@ cuda_py_test( srcs = ["lib/session_debug_multi_gpu_test.py"], python_version = "PY3", tags = ["no_windows_gpu"], - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Node names are different with autojit deps = [ ":debug_data", @@ -1224,7 +1219,6 @@ cuda_py_test( srcs = ["cli/analyzer_cli_test.py"], python_version = "PY3", tags = ["no_windows"], # TODO: needs investigation on Windows - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Node names are different with autojit deps = [ ":analyzer_cli", diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index c9599747f3f..56693bd0033 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -68,10 +68,12 @@ py_library( ":collective_util", ":cross_device_utils", ":device_util", + ":distribute_utils", ":ps_values", ":reduce_util", ":tpu_values", ":values", + ":values_util", "//tensorflow/python:array_ops", "//tensorflow/python:device_lib", "//tensorflow/python:framework_ops", @@ -82,8 +84,10 @@ py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:tf_export", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:executor", "//tensorflow/tools/docs:doc_controls", + "@enum34_archive//:enum", "@six_archive//:six", ], ) @@ -149,8 +153,9 @@ py_library( ":multi_process_runner", ":multi_worker_test_base", ":one_device_strategy", + ":parameter_server_strategy_v2", ":sharded_variable", - "//tensorflow/python/distribute/client", + "//tensorflow/python/distribute/coordinator:cluster_coordinator", "//tensorflow/python/distribute/experimental", ], ) @@ -333,6 +338,7 @@ py_library( name = "mirrored_strategy", srcs = ["mirrored_strategy.py"], deps = [ + ":collective_util", ":cross_device_ops", ":device_util", ":distribute_lib", @@ -421,18 +427,27 @@ py_library( srcs = ["collective_all_reduce_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ + ":collective_util", ":cross_device_ops", ":cross_device_utils", + ":device_util", + ":distribute_lib", + ":distribute_utils", ":input_lib", ":mirrored_strategy", ":multi_worker_util", ":numpy_dataset", + ":reduce_util", ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:collective_ops", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:tf_export", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/eager:context", ], @@ -457,6 +472,7 @@ cuda_py_test( ], python_version = "PY3", tags = [ + "notap", # TODO(b/171355671) "notsan", # TODO(b/151841995) ], deps = [ @@ -665,12 +681,22 @@ py_library( py_library( name = "collective_util", srcs = ["collective_util.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow/python:util", "//tensorflow/python:variable_scope", ], ) +tf_py_test( + name = "collective_util_test", + srcs = ["collective_util_test.py"], + deps = [ + ":collective_util", + "//tensorflow/python/eager:test", + ], +) + py_library( name = "shared_variable_creator", srcs = ["shared_variable_creator.py"], @@ -908,6 +934,9 @@ tf_py_test( name = "multi_worker_test_base_test", srcs = ["multi_worker_test_base_test.py"], srcs_version = "PY2AND3", + tags = [ + "no_oss", # TODO(b/170834611) + ], deps = [ ":multi_worker_test_base", ], @@ -935,7 +964,6 @@ cuda_py_test( distribute_py_test( name = "checkpointing_test", srcs = ["checkpointing_test.py"], - disable_mlir_bridge = False, main = "checkpointing_test.py", tags = [ "multi_and_single_gpu", @@ -990,22 +1018,24 @@ distribute_py_test( "multi_and_single_gpu", ], deps = [ - ":collective_all_reduce_strategy", ":combinations", - ":input_lib", - ":mirrored_strategy", - ":multi_worker_test_base", - ":reduce_util", + ":distribute_lib", ":strategy_combinations", + ":test_util", ":tpu_strategy", ":values", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:errors", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:composite_tensor", + "//tensorflow/python:dtypes", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:tf2", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:context", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", "//tensorflow/python/ops/ragged:ragged_tensor", "//third_party/py/numpy", @@ -1017,7 +1047,6 @@ cuda_py_test( name = "cross_device_utils_test", srcs = ["cross_device_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":combinations", ":cross_device_utils", @@ -1044,6 +1073,7 @@ cuda_py_test( ":collective_util", ":combinations", ":cross_device_ops", + ":cross_device_utils", ":multi_process_runner", ":multi_worker_test_base", ":reduce_util", @@ -1083,9 +1113,15 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", + "//tensorflow/python:composite_tensor", + "//tensorflow/python:embedding_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:partitioned_variables", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python:tf_export", + "//tensorflow/python:type_spec", + "//tensorflow/python:util", "//tensorflow/python:variables", "//tensorflow/python/saved_model:save_context", "//tensorflow/python/training/saving:saveable_object_util", @@ -1097,19 +1133,28 @@ tf_py_test( name = "sharded_variable_test", size = "small", srcs = ["sharded_variable_test.py"], + tags = [ + # depend through //third_party/tensorflow/python:extra_py_tests_deps. + "ignore_for_dep=third_party.tensorflow.python.keras.engine.base_layer", + ], deps = [ ":sharded_variable", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:embedding_ops", "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:framework_ops", "//tensorflow/python:session", + "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/compat:v2_compat", "//tensorflow/python/eager:def_function", + "//tensorflow/python/module", "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:save", "//tensorflow/python/saved_model:signature_constants", @@ -1153,7 +1198,6 @@ distribute_py_test( name = "values_test", size = "medium", srcs = ["values_test.py"], - disable_mlir_bridge = False, main = "values_test.py", shard_count = 5, tags = [ @@ -1162,9 +1206,11 @@ distribute_py_test( "notsan", # b/168645872 ], tpu_tags = [ - "no_oss", # b/150954621 Target too big to run serially reliably. + "no_oss", # TODO(b/150954621) Target too big to run serially reliably. + "noasan", # TODO(b/337374867) fails with -fsanitize=null ], deps = [ + ":collective_all_reduce_strategy", ":combinations", ":distribute_lib", ":distribute_utils", @@ -1176,7 +1222,6 @@ distribute_py_test( ":tpu_strategy", ":tpu_values", ":values", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:constant_op", @@ -1186,7 +1231,6 @@ distribute_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:indexed_slices", "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", "//tensorflow/python:saver", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", @@ -1197,14 +1241,12 @@ distribute_py_test( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", + "//tensorflow/python/saved_model:save", "//tensorflow/python/saved_model:save_context", "//tensorflow/python/saved_model:save_options", - "//tensorflow/python/saved_model/model_utils:mode_keys", - "//tensorflow/python/tpu:tpu_lib", "//tensorflow/python/types", "@absl_py//absl/testing:parameterized", ], @@ -1301,7 +1343,6 @@ distribute_py_test( distribute_py_test( name = "moving_averages_test", srcs = ["moving_averages_test.py"], - disable_mlir_bridge = False, main = "moving_averages_test.py", deps = [ ":combinations", @@ -1650,6 +1691,7 @@ py_test( srcs = ["multi_process_runner_test.py"], python_version = "PY3", shard_count = 12, + tags = ["no_oss_py38"], #TODO(b/171435331) deps = [ ":multi_process_runner", ":multi_worker_test_base", @@ -1738,7 +1780,7 @@ distribute_py_test( srcs = ["strategy_gather_test.py"], disable_mlir_bridge = False, python_version = "PY3", - shard_count = 2, + shard_count = 4, tags = [ "multi_and_single_gpu", "notsan", # TODO(b/160006974) @@ -1852,18 +1894,19 @@ tf_py_test( srcs = ["parameter_server_strategy_v2_test.py"], python_version = "PY3", tags = [ - "notsan", # b/168675975 + "no_windows", # TODO(171349346) + "notsan", # TODO(b/168675975) ], deps = [ ":multi_worker_test_base", ":parameter_server_strategy_v2", ":sharded_variable", + "//tensorflow:tensorflow_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:init_ops_v2", "//tensorflow/python:linalg_ops_impl", - "//tensorflow/python:partitioned_variables", "//tensorflow/python:variables", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/eager:context", diff --git a/tensorflow/python/distribute/central_storage_strategy.py b/tensorflow/python/distribute/central_storage_strategy.py index e61570dd6bd..caa03f84193 100644 --- a/tensorflow/python/distribute/central_storage_strategy.py +++ b/tensorflow/python/distribute/central_storage_strategy.py @@ -102,6 +102,13 @@ class CentralStorageStrategy(distribute_lib.Strategy): Returns: A "distributed `Dataset`" that the caller can iterate over. """ + if (options and options.experimental_replication_moden == + distribute_lib.InputReplicationMode.PER_REPLICA): + raise NotImplementedError( + 'InputReplicationMode.PER_REPLICA ' + 'is only supported in ' + '`experimental_distribute_datasets_from_function`.' + ) return super(CentralStorageStrategy, self).experimental_distribute_dataset( dataset, options) diff --git a/tensorflow/python/distribute/client/client_test.py b/tensorflow/python/distribute/client/client_test.py deleted file mode 100644 index 981ad964b6d..00000000000 --- a/tensorflow/python/distribute/client/client_test.py +++ /dev/null @@ -1,374 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for client.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import platform -import sys -import threading -import time - -from tensorflow.python.distribute.client import client -from tensorflow.python.eager import cancellation -from tensorflow.python.eager import def_function -from tensorflow.python.framework import errors -from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import coordinator -from tensorflow.python.util import nest - - -class CoordinatedClosureQueueTest(test.TestCase): - - def testBasic(self): - queue = client._CoordinatedClosureQueue() - closure1 = self._create_closure(queue._cancellation_mgr) - queue.put(closure1) - self.assertIs(closure1, queue.get()) - self.assertFalse(queue.done()) - queue.put_back(closure1) - self.assertEqual(closure1, queue.get()) - queue.mark_finished() - self.assertTrue(queue.done()) - queue.wait() - - def testProcessAtLeaseOnce(self): - closure_queue = client._CoordinatedClosureQueue() - labels = ['A', 'B', 'C', 'D', 'E'] - processed_count = collections.defaultdict(int) - - coord = coordinator.Coordinator(clean_stop_exception_types=[]) - - def process_queue(): - with coord.stop_on_exception(): - has_been_put_back = False - while True: - closure = closure_queue.get(timeout=30) - if closure is None: - break - if not has_been_put_back: - has_been_put_back = True - closure_queue.put_back(closure) - continue - closure._function() - closure_queue.mark_finished() - - def get_func(label): - - def func(): - time.sleep(3) - processed_count[label] += 1 - - return func - - cm = cancellation.CancellationManager() - for label in labels: - closure_queue.put(client.Closure(get_func(label), cm)) - t1 = threading.Thread(target=process_queue, daemon=True) - t1.start() - t2 = threading.Thread(target=process_queue, daemon=True) - t2.start() - - # Make sure multiple wait() calls are fine. - closure_queue.wait() - closure_queue.wait() - closure_queue.wait() - closure_queue.wait() - - self.assertEqual(processed_count, collections.Counter(labels)) - - coord.join([t1, t2]) - - def testNotifyBeforeWait(self): - closure_queue = client._CoordinatedClosureQueue() - - def func(): - logging.info('func running') - - coord = coordinator.Coordinator(clean_stop_exception_types=[]) - - def process_queue(): - with coord.stop_on_exception(): - closure_queue.get() - closure_queue.mark_finished() - - closure_queue.put(client.Closure(func, closure_queue._cancellation_mgr)) - t = threading.Thread(target=process_queue) - t.start() - coord.join([t]) - - # This test asserts that waiting at the time the function has been processed - # doesn't time out. - closure_queue.wait() - - def _assert_one_unblock_the_other(self, first_fn, second_fn): - """Asserts `second_fn` wouldn't return before `first_fn` is finished.""" - first_fn_done = threading.Event() - second_fn_done = threading.Event() - coord = coordinator.Coordinator(clean_stop_exception_types=[]) - - def wrapped_first_fn(): - with coord.stop_on_exception(): - self.assertFalse(second_fn_done.is_set()) - first_fn() - first_fn_done.set() - - self.assertFalse(first_fn_done.is_set()) - t = threading.Thread(target=wrapped_first_fn) - t.start() - - second_fn() - self.assertTrue(first_fn_done.is_set()) - second_fn_done.set() - - coord.join([t]) - - def testWaitRaiseErrorAfterMarkFailure(self): - if sys.version_info >= (3, 8) and platform.system() == 'Windows': - # TODO(b/165013260): Fix this - self.skipTest('Test is currently broken on Windows with Python 3.8') - - closure_queue = client._CoordinatedClosureQueue() - closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) - closure = closure_queue.get() - - wait_finish_event = threading.Event() - coord = coordinator.Coordinator(clean_stop_exception_types=[]) - - # Using a thread to verify that closure_queue.wait() will not return until - # all inflight closures are finished. - - def mark_finished_fn(): - try: - raise ValueError('Some error.') - except ValueError as e: - closure_queue.mark_failed(e) - - def wait_fn(): - with self.assertRaises(ValueError): - closure_queue.wait() - - self._assert_one_unblock_the_other(mark_finished_fn, wait_fn) - - self.assertTrue(closure_queue.done()) - - def _create_closure(self, cancellation_mgr): - - @def_function.function() - def some_function(): - return 1.0 - - return client.Closure(some_function, cancellation_mgr) - - def _put_two_closures_and_get_one(self): - closure_queue = client._CoordinatedClosureQueue() - closure1 = self._create_closure(closure_queue._cancellation_mgr) - closure_queue.put(closure1) - - closure2 = self._create_closure(closure_queue._cancellation_mgr) - closure_queue.put(closure2) - - closure_got = closure_queue.get() # returns closure1 - self.assertIs(closure_got, closure1) - self.assertIsNot(closure_got, closure2) - return closure_queue, closure1, closure2 - - def testPutRaiseError(self): - if sys.version_info >= (3, 8) and platform.system() == 'Windows': - # TODO(b/165013260): Fix this - self.skipTest('Test is currently broken on Windows with Python 3.8') - - closure_queue, _, closure2 = self._put_two_closures_and_get_one() - - closure_queue.mark_failed(ValueError()) - - with self.assertRaises(ValueError): - closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) - - self.assertTrue(closure_queue.done()) - - with self.assertRaisesRegex( - errors.CancelledError, - 'The corresponding function is cancelled. Please reschedule the ' - 'function.'): - closure2._fetch_output_remote_values() - - # The error is cleared. - closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) - - def testWaitRaiseError(self): - if sys.version_info >= (3, 8) and platform.system() == 'Windows': - # TODO(b/165013260): Fix this - self.skipTest('Test is currently broken on Windows with Python 3.8') - - closure_queue, _, closure2 = self._put_two_closures_and_get_one() - - closure_queue.mark_failed(ValueError()) - - with self.assertRaises(ValueError): - closure_queue.wait() - self.assertTrue(closure_queue.done()) - - with self.assertRaisesRegex( - errors.CancelledError, - 'The corresponding function is cancelled. Please reschedule the ' - 'function.'): - closure2._fetch_output_remote_values() - - # The error is cleared. - closure_queue.wait() - - def testDoneRaiseError(self): - if sys.version_info >= (3, 8) and platform.system() == 'Windows': - # TODO(b/165013260): Fix this - self.skipTest('Test is currently broken on Windows with Python 3.8') - - closure_queue, _, _ = self._put_two_closures_and_get_one() - - self.assertFalse(closure_queue.done()) - closure_queue.mark_failed(ValueError()) - with self.assertRaises(ValueError): - closure_queue.done() - - def _set_error(self, closure_queue, closure, error): - try: - raise error - except Exception as e: # pylint: disable=broad-except - nest.map_structure(lambda x: x._set_error(e), - closure._output_remote_values) - closure_queue.mark_failed(e) - - def _test_cancel_closure_when_error(self, call_wait): - if sys.version_info >= (3, 8) and platform.system() == 'Windows': - # TODO(b/165013260): Fix this - self.skipTest('Test is currently broken on Windows with Python 3.8') - - closure_queue, closure1, closure2 = self._put_two_closures_and_get_one() - closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) - closure_queue.get() - # At this moment, there are two inflight, one in queue. - self.assertEqual(closure_queue._inflight_closure_count, 2) - - # Hold a copy of the queue's cancellation manager at this point - initial_cm = closure_queue._cancellation_mgr - - # Simulating closure1 fails. - self._set_error(closure_queue, closure1, ValueError('Some error.')) - - # At this moment, there are one inflight, one in queue. - self.assertEqual(closure_queue._queue.qsize(), 1) - self.assertEqual(closure_queue._inflight_closure_count, 1) - - closure3 = self._create_closure(closure_queue._cancellation_mgr) - - def fake_cancellation(): - self._set_error(closure_queue, closure2, - ValueError('Fake cancellation error.')) - - def report_error(): - # It should not report the fake cancellation error. - with self.assertRaisesRegex(ValueError, 'Some error.'): - # Verifying `wait()` or `put()` raises even if one closure is in - # flight. - if call_wait: - closure_queue.wait() - else: - closure_queue.put(closure3) - - self._assert_one_unblock_the_other(fake_cancellation, report_error) - - # The original cancellation manager of the queue has been cancelled. - self.assertTrue(initial_cm.is_cancelled) - - # At this moment, there is zero inflight, nothing in queue. - self.assertTrue(closure_queue._queue.empty()) - self.assertEqual(closure_queue._inflight_closure_count, 0) - self.assertIsNone(closure_queue._error) - - # This asserts that closure1 has errored. - with self.assertRaisesRegex(ValueError, 'Some error.'): - closure1._fetch_output_remote_values() - - # The following asserts that closure3 should have been cancelled. - if not call_wait: - with self.assertRaisesRegex( - errors.CancelledError, - 'The corresponding function is cancelled. Please reschedule the ' - 'function.'): - closure3._fetch_output_remote_values() - - # Closure2 was an inflight closure when it got cancelled. - self.assertEqual(closure2._output_remote_values._status, - client._RemoteValueStatus.READY) - with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'): - closure2._fetch_output_remote_values() - - # This asserts that the queue has a clear state. - self.testBasic() - - def testWaitRaiseErrorAfterCancelClosure(self): - self._test_cancel_closure_when_error(call_wait=True) - - def testPutRaiseErrorAfterCancelClosure(self): - self._test_cancel_closure_when_error(call_wait=False) - - def testStateIsRestoredAfterJoinIsCalled(self): - if sys.version_info >= (3, 8) and platform.system() == 'Windows': - # TODO(b/165013260): Fix this - self.skipTest('Test is currently broken on Windows with Python 3.8') - - closure_queue, _, _ = self._put_two_closures_and_get_one() - self.assertEqual(closure_queue._inflight_closure_count, 1) - closure_queue.mark_failed(ValueError('test error')) - with self.assertRaises(ValueError): - closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) - - # Its error should have been cleared. - self.assertIsNone(closure_queue._error) - closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) - self.assertIsNone(closure_queue._error) - - def testThreadSafey(self): - thread_count = 10 - queue = client._CoordinatedClosureQueue() - - # Each thread performs 20 queue actions: 10 are `put_back` and 10 are - # `mark_finished`. - action_count = 20 - - def func(): - for i in range(action_count): - closure = queue.get() - if i % 2 == 0: - queue.put_back(closure) - else: - queue.mark_finished() - - threads = [threading.Thread(target=func) for i in range(thread_count)] - for t in threads: - t.start() - - for _ in range(thread_count * action_count // 2): - queue.put(self._create_closure(queue._cancellation_mgr)) - queue.wait() - self.assertTrue(queue.done()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/distribute/client/parameter_server_client_test.py b/tensorflow/python/distribute/client/parameter_server_client_test.py deleted file mode 100644 index 022539308d1..00000000000 --- a/tensorflow/python/distribute/client/parameter_server_client_test.py +++ /dev/null @@ -1,471 +0,0 @@ -# Lint as: python3 -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `Client` when used together with `ParameterServerStrategyV2.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import threading - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import distribution_strategy_context -from tensorflow.python.distribute import multi_worker_test_base -from tensorflow.python.distribute import parameter_server_strategy_v2 -from tensorflow.python.distribute.client import client as client_lib -from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver -from tensorflow.python.eager import def_function -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import tensor_spec -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training.server_lib import ClusterSpec - - -class ErrorReportingThread(threading.Thread): - - error = None - - def __init__(self, *args, **kwargs): - assert "target" in kwargs - target = kwargs["target"] - - @functools.wraps(target) - def wrapped_target(*args, **kwargs): - try: - return target(*args, **kwargs) - except Exception as e: # pylint: disable=broad-except - ErrorReportingThread.error = e - - kwargs["target"] = wrapped_target - super(ErrorReportingThread, self).__init__(*args, **kwargs) - - -class TestCaseWithErrorReportingThread(test.TestCase): - - @classmethod - def setUpClass(cls): - cls._threading_thread = threading.Thread - threading.Thread = ErrorReportingThread - super(TestCaseWithErrorReportingThread, cls).setUpClass() - - @classmethod - def tearDownClass(cls): - super(TestCaseWithErrorReportingThread, cls).tearDownClass() - threading.Thread = cls._threading_thread - - def setUp(self): - ErrorReportingThread.error = None - super(TestCaseWithErrorReportingThread, self).setUp() - - def tearDown(self): - super(TestCaseWithErrorReportingThread, self).tearDown() - if ErrorReportingThread.error: - raise ErrorReportingThread.error # pylint: disable=raising-bad-type - - -def make_client(num_workers, num_ps): - # TODO(rchao): Test the internal rpc_layer version. - cluster_def = multi_worker_test_base.create_in_process_cluster( - num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc") - cluster_def["chief"] = [ - "localhost:%d" % multi_worker_test_base.pick_unused_port() - ] - cluster_resolver = SimpleClusterResolver( - ClusterSpec(cluster_def), rpc_layer="grpc") - strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( - cluster_resolver) - return client_lib.Client(strategy) - - -class ParameterServerClientTest(TestCaseWithErrorReportingThread): - - @classmethod - def setUpClass(cls): - super(ParameterServerClientTest, cls).setUpClass() - cls.client = make_client(num_workers=3, num_ps=2) - cls.strategy = cls.client.strategy - - def testBasic(self): - self.strategy.extended._variable_count = 0 - with self.strategy.scope(): - v1 = variables.Variable(initial_value=0.0) - v2 = variables.Variable(initial_value=1.0) - self.assertEqual(self.strategy.extended._variable_count, 2) - - @def_function.function - def worker_fn(): - v1.assign_add(0.1) - v2.assign_sub(0.2) - return v1.read_value() / v2.read_value() - - results = self.client.schedule(worker_fn) - logging.info("Results of experimental_run_v2: %f", - self.client.fetch(results)) - - self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6) - self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6) - - def testFnReturnNestedValues(self): - x = constant_op.constant(1) - - @def_function.function - def f(): - return x + 1, (x + 2, x + 3), [x + 4], {"v": x} - - got = self.client.schedule(f) - want = 2, (3, 4), [5], {"v": 1} - self.assertEqual(self.client.fetch(got), want) - - def testInputFunction(self): - - def input_fn(): - return dataset_ops.DatasetV2.range(1, 2) - - with self.strategy.scope(): - v = variables.Variable(initial_value=0, dtype=dtypes.int64) - - @def_function.function - def worker_fn(iterator): - x = next(iterator) - v.assign_add(x) - return x - - distributed_dataset = self.client.create_per_worker_dataset(input_fn) - result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),)) - result = self.client.fetch(result) - self.assertEqual(result, (1,)) - result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),)) - result = self.client.fetch(result) - self.assertEqual(result, (1,)) - - self.assertAlmostEqual(v.read_value().numpy(), 2, delta=1e-6) - - def testAsyncScheduleAndJoin(self): - - def input_fn(): - return dataset_ops.DatasetV2.from_tensor_slices([2] * 10) - - with self.strategy.scope(): - v = variables.Variable(initial_value=0, dtype=dtypes.int32) - - # TODO(yuefengz): the following tf.function has a return value which is None - # in its structured_outputs. - @def_function.function - def worker_fn(iterator): - x = next(iterator) - v.assign_add(x) - - distributed_dataset = self.client.create_per_worker_dataset(input_fn) - - iterator = iter(distributed_dataset) - - # Verifying joining without any scheduling doesn't hang. - self.client.join() - self.assertEqual(v.read_value().numpy(), 0) - - for _ in range(5): - self.client.schedule(worker_fn, args=(iterator,)) - self.client.join() - - # With 5 addition it should be 2*5 = 10. - self.assertEqual(v.read_value().numpy(), 10) - - for _ in range(5): - self.client.schedule(worker_fn, args=(iterator,)) - - # Verifying multiple join is fine. - self.client.join() - self.client.join() - self.client.join() - - self.assertTrue(self.client.done()) - - # Likewise, it's now 20. - self.assertEqual(v.read_value().numpy(), 20) - - def testInputFunctionWithMap(self): - self._map_fn_tracing_count = 0 - - def input_fn(): - def map_fn(x): - self._map_fn_tracing_count += 1 - return x + 10 - return dataset_ops.DatasetV2.range(0, 10).map(map_fn) - - @def_function.function - def worker_fn(iterator): - return next(iterator) - - distributed_dataset = ( - self.client.create_per_worker_dataset(input_fn)) - result = self.client.schedule( - worker_fn, args=(iter(distributed_dataset),)) - self.assertEqual(result.fetch(), (10,)) - self.assertEqual(self._map_fn_tracing_count, 1) - - def testInputFunctionCreateVariables(self): - - def input_fn(): - v = variables.Variable(initial_value=0.0) - return v.read_value() - - with self.assertRaises(ValueError): - self.client.create_per_worker_dataset(input_fn) - - def testPerWorkerValue(self): - var_shape = tuple() - var_dtype = dtypes.float32 - var_name = "var" - - def create_var(): - var = variables.Variable( - initial_value=0.0, dtype=var_dtype, name=var_name) - self.assertIn("worker", var.device) - return var - - worker_local_var = self.client._create_per_worker_resources(create_var) - - # The following is a workaround to allow `worker_local_var` to be passed in - # as args to the `client.schedule` method which requires tensor specs to - # trace tf.function but _create_worker_resources' return values don't have - # tensor specs. We can get rid of this workaround once - # _create_worker_resources is able to infer the tensor spec of the return - # value of the function passed in. See b/154675763. - for var in worker_local_var._values: - var._set_type_spec(tensor_spec.TensorSpec(var_shape, var_dtype, var_name)) - - def worker_fn(var): - var.assign_add(1.0) - - for _ in range(10): - # Which slice of `worker_local_var` will be used will depend on which - # worker the `worker_fn` gets scheduled on. - self.client.schedule(worker_fn, args=(worker_local_var,)) - self.client.join() - - var_sum = sum(self.client.fetch(worker_local_var._values)) - self.assertEqual(var_sum, 10.0) - - def testDisallowRemoteValueAsInput(self): - - @def_function.function - def func_0(): - return 1.0 - - @def_function.function - def func_1(x): - return x + 1.0 - - remote_v = self.client.schedule(func_0) - with self.assertRaises(ValueError): - self.client.schedule(func_1, args=(remote_v,)) - - -class LimitedClosureQueueSizeBasicTest(ParameterServerClientTest): - """Test basic functionality works with explicit maximum closure queue size. - - Execute the same set of test cases as in `ParameterServerClientTest`, with an - explicit size limit for the closure queue. Note that even when the queue size - is set to infinite, there is still a maximum practical size (depends on host - memory limit) that might cause the queue.put operations to be blocking when - scheduling a large number of closures on a big cluster. These tests make sure - that the client does not run into deadlocks in such scenario. - """ - - @classmethod - def setUpClass(cls): - super(LimitedClosureQueueSizeBasicTest, cls).setUpClass() - client_lib._CLOSURE_QUEUE_MAX_SIZE = 2 - cls.client = make_client(num_workers=3, num_ps=2) - cls.strategy = cls.client.strategy - - -class ErrorReportingTest(TestCaseWithErrorReportingThread): - - @classmethod - def setUpClass(cls): - super(ErrorReportingTest, cls).setUpClass() - cls.client = make_client(num_workers=3, num_ps=2) - cls.strategy = cls.client.strategy - - with cls.strategy.scope(): - cls.iteration = variables.Variable(initial_value=0.0) - - @def_function.function - def _normal_function(self): - x = random_ops.random_uniform((2, 10)) - y = random_ops.random_uniform((10, 2)) - self.iteration.assign_add(1.0) - return math_ops.reduce_mean(math_ops.matmul(x, y)) - - @def_function.function - def _error_function(self): - x = random_ops.random_uniform((2, 10)) - y = random_ops.random_uniform((10, 2)) - check_ops.assert_non_positive_v2(math_ops.reduce_sum(math_ops.matmul(x, y))) - self.iteration.assign_add(1.0) - return self.iteration - - @def_function.function - def _long_function(self): - x = random_ops.random_uniform((1000, 1000)) - for _ in math_ops.range(10000): - a = random_ops.random_uniform((1000, 1000)) - b = random_ops.random_uniform((1000, 1000)) - x += math_ops.matmul(a, b) - return x - - def testJoinRaiseError(self): - for _ in range(3): - self.client.schedule(self._normal_function) - self.client.schedule(self._error_function) - with self.assertRaises(errors.InvalidArgumentError): - self.client.join() - - def testScheduleRaiseError(self): - for _ in range(3): - self.client.schedule(self._normal_function) - self.client.schedule(self._error_function) - with self.assertRaises(errors.InvalidArgumentError): - while True: - self.client.schedule(self._normal_function) - - def testScheduleRaiseErrorWithMultipleFailure(self): - for _ in range(3): - self.client.schedule(self._normal_function) - self.client.schedule(self._error_function) - with self.assertRaises(errors.InvalidArgumentError): - while True: - self.client.schedule(self._error_function) - self.client.join() - - def testErrorWillbeCleared(self): - self.client.schedule(self._error_function) - with self.assertRaises(errors.InvalidArgumentError): - self.client.join() - - for _ in range(3): - self.client.schedule(self._normal_function) - self.client.schedule(self._error_function) - with self.assertRaises(errors.InvalidArgumentError): - self.client.join() - - def testRemoteValueReturnError(self): - result = self.client.schedule(self._error_function) - - with self.assertRaises(errors.InvalidArgumentError): - result.fetch() - - # Clear the error. - with self.assertRaises(errors.InvalidArgumentError): - self.client.join() - - def testInputError(self): - - worker_local_val = self.client._create_per_worker_resources( - self._error_function) - - @def_function.function - def func(x): - return x + 1 - - result = self.client.schedule(func, args=(worker_local_val,)) - with self.assertRaises(client_lib.InputError): - self.client.join() - - with self.assertRaises(client_lib.InputError): - result.fetch() - - def testCancellation(self): - for _ in range(3): - self.client.schedule(self._normal_function) - long_function = self.client.schedule(self._long_function) - self.client.schedule(self._error_function) - - with self.assertRaises(errors.InvalidArgumentError): - self.client.join() - - with self.assertRaises(errors.CancelledError): - long_function.fetch() - - for _ in range(3): - self.client.schedule(self._normal_function) - self.client.join() - - -class LimitedClosureQueueErrorTest(ErrorReportingTest): - """Test error reporting works with explicit maximum closure queue size. - - Execute the same set of test cases as in ErrorReportingTest, with an explicit - size limit for the closure queue. - """ - - @classmethod - def setUpClass(cls): - super(LimitedClosureQueueErrorTest, cls).setUpClass() - client_lib._CLOSURE_QUEUE_MAX_SIZE = 2 - cls.client = make_client(num_workers=3, num_ps=2) - cls.strategy = cls.client.strategy - - with cls.client.strategy.scope(): - cls.iteration = variables.Variable(initial_value=0.0) - - -class StrategyRunTest(test.TestCase): - - @classmethod - def setUpClass(cls): - super(StrategyRunTest, cls).setUpClass() - cls.client = make_client(num_workers=1, num_ps=1) - cls.strategy = cls.client.strategy - - def testStrategyRun(self): - self.assertFalse(distribution_strategy_context.in_cross_replica_context()) - with self.strategy.scope(): - self.assertTrue(distribution_strategy_context.in_cross_replica_context()) - v = variables.Variable(initial_value=1) - - @def_function.function - def worker_fn(input_tensor): - - def replica_fn(input_tensor): - # Within `replica_fn`, it has to be in a replica context. - self.assertFalse( - distribution_strategy_context.in_cross_replica_context()) - return input_tensor + v - - return self.strategy.run(replica_fn, args=(input_tensor,)) - - # Asserting scheduling in scope has the expected behavior. - result = self.client.schedule(worker_fn, args=(constant_op.constant(3),)) - self.assertIsInstance(result, client_lib.RemoteValue) - self.assertEqual(result.fetch(), 4) - - # Asserting scheduling out of scope has the expected behavior. - result = self.client.schedule(worker_fn, args=(constant_op.constant(3),)) - self.assertEqual(result.fetch(), 4) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/python/distribute/cluster_resolver/BUILD b/tensorflow/python/distribute/cluster_resolver/BUILD index 6fda3f5311d..e925104ffa3 100644 --- a/tensorflow/python/distribute/cluster_resolver/BUILD +++ b/tensorflow/python/distribute/cluster_resolver/BUILD @@ -98,7 +98,6 @@ tf_py_test( name = "base_cluster_resolver_py_test", srcs = ["cluster_resolver_test.py"], main = "cluster_resolver_test.py", - tfrt_enabled = True, deps = [ ":base_cluster_resolver_py", "//tensorflow/python:client_testlib", @@ -114,7 +113,6 @@ tf_py_test( size = "small", srcs = ["gce_cluster_resolver_test.py"], main = "gce_cluster_resolver_test.py", - tfrt_enabled = True, deps = [ ":gce_cluster_resolver_py", "//tensorflow/python:client_testlib", @@ -131,7 +129,6 @@ tf_py_test( srcs = ["tfconfig_cluster_resolver_test.py"], grpc_enabled = True, main = "tfconfig_cluster_resolver_test.py", - tfrt_enabled = True, deps = [ ":tfconfig_cluster_resolver_py", "//tensorflow/python:client_testlib", @@ -148,7 +145,6 @@ tf_py_test( srcs = ["sagemaker_cluster_resolver_test.py"], grpc_enabled = True, main = "sagemaker_cluster_resolver_test.py", - tfrt_enabled = True, deps = [ ":sagemaker_cluster_resolver_py", "//tensorflow/python:client_testlib", @@ -165,7 +161,6 @@ tf_py_test( srcs = ["slurm_cluster_resolver_test.py"], main = "slurm_cluster_resolver_test.py", tags = [], - tfrt_enabled = True, deps = [ ":slurm_cluster_resolver_py", "//tensorflow/python:client_testlib", @@ -181,7 +176,6 @@ tf_py_test( size = "small", srcs = ["kubernetes_cluster_resolver_test.py"], main = "kubernetes_cluster_resolver_test.py", - tfrt_enabled = True, deps = [ ":kubernetes_cluster_resolver_py", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/distribute/cluster_resolver/tpu/BUILD b/tensorflow/python/distribute/cluster_resolver/tpu/BUILD index 01b21f73dee..dff2b6937f7 100644 --- a/tensorflow/python/distribute/cluster_resolver/tpu/BUILD +++ b/tensorflow/python/distribute/cluster_resolver/tpu/BUILD @@ -32,7 +32,6 @@ tf_py_test( grpc_enabled = True, main = "tpu_cluster_resolver_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":tpu_cluster_resolver_py", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index 5719fc6870a..37a440bf46e 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -37,6 +37,7 @@ from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import ClusterResolver from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver from tensorflow.python.eager import context @@ -46,10 +47,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.tracking import base +from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export -@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[]) +# pylint: disable=line-too-long +@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[]) class CollectiveAllReduceStrategy(distribute_lib.Strategy): """A distribution strategy for synchronous training on multiple workers. @@ -63,7 +66,12 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy): `cluster_resolver` correctly. For example, if you are using `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to have its corresponding `task_type` and `task_id` set in the `TF_CONFIG` - environment variable. + environment variable. An example TF_CONFIG on worker-0 of a two worker cluster + is: + + ``` + TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }' + ``` Your program runs on each worker as-is. Note that collectives require each worker to participate. All `tf.distribute` and non `tf.distribute` API may use @@ -76,8 +84,57 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy): strategy uses. If it's zero, the strategy uses the CPU. All workers need to use the same number of devices, otherwise the behavior is undefined. - This strategy is not intended for TPU. Use - `tf.distribute.experimental.TPUStrategy` instead. + This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy` + instead. + + After setting up TF_CONFIG, using this strategy is similar to using + `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`. + + ``` + strategy = tf.distribute.MultiWorkerMirroredStrategy() + + with strategy.scope(): + model = tf.keras.Sequential([ + tf.keras.layers.Dense(2, input_shape=(5,)), + ]) + optimizer = tf.keras.optimizers.SGD(learning_rate=0.1) + + def dataset_fn(ctx): + x = np.random.random((2, 5)).astype(np.float32) + y = np.random.randint(2, size=(2, 1)) + dataset = tf.data.Dataset.from_tensor_slices((x, y)) + return dataset.repeat().batch(1, drop_remainder=True) + dist_dataset = strategy.distribute_datasets_from_function(dataset_fn) + + model.compile() + model.fit(dist_dataset) + ``` + + You can also write your own training loop: + + ``` + @tf.function + def train_step(iterator): + + def step_fn(inputs): + features, labels = inputs + with tf.GradientTape() as tape: + logits = model(features, training=True) + loss = tf.keras.losses.sparse_categorical_crossentropy( + labels, logits) + + grads = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients(zip(grads, model.trainable_variables)) + + strategy.run(step_fn, args=(next(iterator),)) + + for _ in range(NUM_STEP): + train_step(iterator) + ``` + + See + [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) + for a detailed tutorial. __Saving__ @@ -98,33 +155,37 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy): Tensorflow API. """ + # pylint: enable=line-too-long + # TODO(anjalisridhar): Update our guides with examples showing how we can use # the cluster_resolver argument. # The starting number for collective keys. This should only be set in tests. _collective_key_base = 0 - def __init__( - self, - communication=cross_device_ops_lib.CollectiveCommunication.AUTO, - cluster_resolver=None): + def __init__(self, + cluster_resolver=None, + communication_options=None): """Creates the strategy. Args: - communication: optional - `tf.distribute.experimental.CollectiveCommunication`. This is a hint on - the preferred collective communication implementation. Possible values - include `AUTO`, `RING`, and `NCCL`. cluster_resolver: optional `tf.distribute.cluster_resolver.ClusterResolver`. If `None`, `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used. + communication_options: optional + `tf.distribute.experimental.CommunicationOptions`. This configures the + default options for cross device communications. It can be overridden by + options provided to the communication APIs like + `tf.distribute.ReplicaContext.all_reduce`. See + `tf.distribute.experimental.CommunicationOptions` for details. """ - # TODO(b/150151677): consider move communication to CollectiveHints. + if communication_options is None: + communication_options = collective_util.Options() super(CollectiveAllReduceStrategy, self).__init__( CollectiveAllReduceExtended( self, - communication=communication, - cluster_resolver=cluster_resolver)) + cluster_resolver=cluster_resolver, + communication_options=communication_options)) distribute_lib.distribution_strategy_gauge.get_cell("V2").set( "MultiWorkerMirroredStrategy") @@ -135,12 +196,9 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy): "num_replicas_per_worker").set(self.extended._num_gpus_per_worker) @classmethod - def _from_local_devices( - cls, - devices, - communication=cross_device_ops_lib.CollectiveCommunication.AUTO): + def _from_local_devices(cls, devices, communication_options=None): """A convenience method to create an object with a list of devices.""" - obj = cls(communication) + obj = cls(communication_options=communication_options) obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access return obj @@ -157,21 +215,77 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy): return self.extended._cluster_resolver # pylint: disable=protected-access +class _CollectiveAllReduceStrategyExperimentalMeta(type): + + @classmethod + def __instancecheck__(cls, instance): + # This is to make isinstance(tf.distribute.MultiWorkerMirroredStrategy(), + # tf.distribute.experimental.MultiWorkerMirroredStrategy). Some libraries is + # performing such check. + return isinstance(instance, CollectiveAllReduceStrategy) + + +@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[]) +class _CollectiveAllReduceStrategyExperimental( + CollectiveAllReduceStrategy, + metaclass=_CollectiveAllReduceStrategyExperimentalMeta): + + __doc__ = CollectiveAllReduceStrategy.__doc__ + + @deprecation.deprecated( + None, "use distribute.MultiWorkerMirroredStrategy instead") + def __init__(self, + communication=collective_util.CommunicationImplementation.AUTO, + cluster_resolver=None): + """Creates the strategy. + + Args: + communication: optional + `tf.distribute.experimental.CommunicationImplementation`. This is a hint + on the preferred collective communication implementation. Possible + values include `AUTO`, `RING`, and `NCCL`. + cluster_resolver: optional + `tf.distribute.cluster_resolver.ClusterResolver`. If `None`, + `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used. + """ + communication_options = collective_util.Options( + implementation=communication) + super(_CollectiveAllReduceStrategyExperimental, + self).__init__(cluster_resolver, communication_options) + + @classmethod + def _from_local_devices( + cls, + devices, + communication=collective_util.CommunicationImplementation.AUTO): + """A convenience method to create an object with a list of devices.""" + obj = cls(communication) + obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access + return obj + + +_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__ + + @tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"]) # pylint: disable=missing-docstring class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1): __doc__ = CollectiveAllReduceStrategy.__doc__ - def __init__( - self, - communication=cross_device_ops_lib.CollectiveCommunication.AUTO, - cluster_resolver=None): + # The starting number for collective keys. This should only be set in tests. + _collective_key_base = 0 + + def __init__(self, + communication=collective_util.CommunicationImplementation.AUTO, + cluster_resolver=None): """Initializes the object.""" + communication_options = collective_util.Options( + implementation=communication) super(CollectiveAllReduceStrategyV1, self).__init__( CollectiveAllReduceExtended( self, - communication=communication, - cluster_resolver=cluster_resolver)) + cluster_resolver=cluster_resolver, + communication_options=communication_options)) distribute_lib.distribution_strategy_gauge.get_cell("V1").set( "MultiWorkerMirroredStrategy") # pylint: disable=protected-access @@ -195,17 +309,21 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): _check_health_initial_timeout = 0 # Times to retry before considering the peer is down. _check_health_retry_limit = 3 + # Timeout in seconds the each check health. + _check_health_timeout = 10 - def __init__(self, - container_strategy, - communication, - cluster_resolver): + def __init__(self, container_strategy, cluster_resolver, + communication_options): + if not isinstance(communication_options, collective_util.Options): + raise ValueError("communication_options must be an instance of " + "tf.distribute.experimental.CommunicationOptions") self._cluster_resolver = cluster_resolver or TFConfigClusterResolver() + if not isinstance(self._cluster_resolver, ClusterResolver): + raise ValueError("cluster_resolver must be an instance of " + "tf.distribute.cluster_resolver.ClusterResolver") distribute_lib.StrategyExtendedV1.__init__(self, container_strategy) - assert isinstance( - communication, - cross_device_ops_lib.CollectiveCommunication) - self._communication = communication + self._communication_options = communication_options + self._collective_key_base = container_strategy._collective_key_base # pylint: disable=protected-access self._initialize_strategy(self._cluster_resolver) self._cfer_fn_cache = weakref.WeakKeyDictionary() self.experimental_enable_get_next_as_optional = True @@ -250,19 +368,17 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._worker_device = device_util.canonicalize("/device:CPU:0") self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) - self._collective_keys = cross_device_utils.CollectiveKeys() + self._collective_keys = cross_device_utils.CollectiveKeys( + group_key_start=1 + self._collective_key_base) self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=local_devices, group_size=len(local_devices), - collective_keys=self._collective_keys, - communication=self._communication) + collective_keys=self._collective_keys) # CrossDeviceOps for per host tensors. self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=[self._worker_device], group_size=self._num_workers, - collective_keys=self._collective_keys, - communication=cross_device_ops_lib.CollectiveCommunication.RING, - ) + collective_keys=self._collective_keys) super(CollectiveAllReduceExtended, self)._initialize_single_worker( local_devices) @@ -282,8 +398,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._rpc_layer = cluster_resolver.rpc_layer self._warn_nccl_no_gpu() - logging.info("Single-worker MultiWorkerMirroredStrategy with local_devices " - "= %r, communication = %s", local_devices, self._communication) + logging.info( + "Single-worker MultiWorkerMirroredStrategy with local_devices " + "= %r, communication = %s", local_devices, + self._communication_options.implementation) def _initialize_multi_worker(self, cluster_resolver): """Initializes the object for multi-worker training.""" @@ -366,19 +484,16 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): local_devices = (self._worker_device,) self._collective_keys = cross_device_utils.CollectiveKeys( - group_key_start=1 + CollectiveAllReduceStrategy._collective_key_base) # pylint: disable=protected-access + group_key_start=1 + self._collective_key_base) self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=local_devices, group_size=len(local_devices) * self._num_workers, - collective_keys=self._collective_keys, - communication=self._communication) + collective_keys=self._collective_keys) # CrossDeviceOps for per host tensors. self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=[self._worker_device], group_size=self._num_workers, - collective_keys=self._collective_keys, - communication=cross_device_ops_lib.CollectiveCommunication.RING, - ) + collective_keys=self._collective_keys) super(CollectiveAllReduceExtended, self)._initialize_single_worker( local_devices) @@ -397,9 +512,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): logging.info( "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, " "task_id = %r, num_workers = %r, local_devices = %r, " - "communication = %s", cluster_spec.as_dict(), task_type, - task_id, self._num_workers, local_devices, - self._communication) + "communication = %s", cluster_spec.as_dict(), task_type, task_id, + self._num_workers, local_devices, + self._communication_options.implementation) def __del__(self): self._stop_check_health_thread() @@ -475,6 +590,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): return input_context def _experimental_distribute_dataset(self, dataset, options): + if (options and options.experimental_replication_mode == + distribute_lib.InputReplicationMode.PER_REPLICA): + raise NotImplementedError( + "InputReplicationMode.PER_REPLICA " + "is only supported in " + "`experimental_distribute_datasets_from_function`." + ) input_context = self._make_input_context() return input_lib.get_distributed_dataset( dataset, @@ -484,6 +606,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): input_context=input_context) def _distribute_datasets_from_function(self, dataset_fn, options): + if (options and options.experimental_replication_mode == + distribute_lib.InputReplicationMode.PER_REPLICA): + raise NotImplementedError( + "InputReplicationMode.PER_REPLICA " + "is only supported in " + " `experimental_distribute_datasets_from_function` " + "of tf.distribute.MirroredStrategy") input_context = self._make_input_context() return input_lib.get_distributed_datasets_from_function( dataset_fn=dataset_fn, @@ -570,8 +699,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") if (not ops.executing_eagerly_outside_functions() and - self._communication == - cross_device_ops_lib.CollectiveCommunication.NCCL): + self._communication_options.implementation == + collective_util.CommunicationImplementation.NCCL): updated_config.experimental.collective_nccl = True if not self._cluster_spec: @@ -609,15 +738,14 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): else: return self._host_cross_device_ops - def _gather_to_implementation(self, value, destinations, axis, - experimental_hints): + def _gather_to_implementation(self, value, destinations, axis, options): return self._get_cross_device_ops(value)._gather( # pylint: disable=protected-access value, destinations=destinations, axis=axis, - experimental_hints=experimental_hints) + options=options) - def _reduce_to(self, reduce_op, value, destinations, experimental_hints): + def _reduce_to(self, reduce_op, value, destinations, options): if (isinstance(value, values.Mirrored) and reduce_op == reduce_util.ReduceOp.MEAN): return value @@ -641,7 +769,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): reduce_op, value, destinations=destinations, - experimental_hints=experimental_hints) + options=self._communication_options.merge(options)) def _check_health(self): while True: @@ -654,12 +782,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): while True: attempts += 1 try: - context.context().check_collective_ops_peer_health(peer) + context.context().check_collective_ops_peer_health( + peer, timeout_in_ms=self._check_health_timeout * 1000) # If check_collective_ops_peer_health doesn't raise an Exception, # the peer is healthy. break - except (errors.UnavailableError, - errors.FailedPreconditionError) as e: + except (errors.UnavailableError, errors.FailedPreconditionError, + errors.DeadlineExceededError) as e: # TODO(b/151232436): Always raise UnavailableError when a peer # fails. Now there could be many kinds of errors: # - Unavailable: when the peer is not reachable, e.g. it's down. @@ -703,8 +832,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): reduce_util.ReduceOp.SUM, dummy_value, dummy_value, - experimental_hints=collective_util.Hints( - timeout_seconds=self._check_health_initial_timeout)) + options=collective_util.Options( + timeout_seconds=self._check_health_initial_timeout, + implementation=collective_util.CommunicationImplementation.RING)) if context.is_async(): context.async_wait() except errors.DeadlineExceededError: @@ -730,8 +860,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): logging.info("check health thread stopped") def _warn_nccl_no_gpu(self): - if ((self._communication == - cross_device_ops_lib.CollectiveCommunication.NCCL) and + if ((self._communication_options.implementation == + collective_util.CommunicationImplementation.NCCL) and self._num_gpus_per_worker == 0): logging.warning("Enabled NCCL communication but no GPUs detected/" "specified.") diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py index 305008f6cb8..39d2b432a25 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py @@ -62,6 +62,8 @@ CollectiveAllReduceStrategy = ( collective_all_reduce_strategy.CollectiveAllReduceStrategy) CollectiveAllReduceExtended = ( collective_all_reduce_strategy.CollectiveAllReduceExtended) +_CollectiveAllReduceStrategyExperimental = ( + collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental) def create_test_objects(cluster_spec=None, @@ -610,5 +612,27 @@ class CollectiveAllReduceStrategyV2Test(test.TestCase, parameterized.TestCase): strategy.extended._num_workers, results[1].numpy()) +class ExperimentalCompatibilityTest(test.TestCase): + + def testIsInstance(self): + # It's not uncommon for people to special case MultiWorkerMirroredStrategy, + # so we need to make sure isinstance check works for combinations between + # the experimental and non-experimental endpoints. + strategy = CollectiveAllReduceStrategy() + experimental_strategy = _CollectiveAllReduceStrategyExperimental() + self.assertIsInstance(strategy, CollectiveAllReduceStrategy) + self.assertIsInstance(strategy, _CollectiveAllReduceStrategyExperimental) + self.assertIsInstance(experimental_strategy, CollectiveAllReduceStrategy) + self.assertIsInstance(experimental_strategy, + _CollectiveAllReduceStrategyExperimental) + + def testName(self): + # Estimator checks the __name__ to special case MultiWorkerMirroredStrategy. + self.assertEqual(CollectiveAllReduceStrategy.__name__, + 'CollectiveAllReduceStrategy') + self.assertEqual(_CollectiveAllReduceStrategyExperimental.__name__, + 'CollectiveAllReduceStrategy') + + if __name__ == '__main__': test_util.main() diff --git a/tensorflow/python/distribute/collective_util.py b/tensorflow/python/distribute/collective_util.py index 0d9c404e520..0d4554480b5 100644 --- a/tensorflow/python/distribute/collective_util.py +++ b/tensorflow/python/distribute/collective_util.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,9 +19,145 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy +import enum + +from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export +# TODO(b/170340570): print deprecation warning for CollectiveCommunication. +@tf_export("distribute.experimental.CommunicationImplementation", + "distribute.experimental.CollectiveCommunication") +class CommunicationImplementation(enum.Enum): + """Cross device communication implementation. + + Warning: The alias `tf.distribute.experimental.CollectiveCommunication` is + deprecated and will be removed in a future version. Use + `tf.distribute.experimental.CommunicationImplementation` instead. + + * `AUTO`: Automatically chosen by Tensorflow. + * `RING`: TensorFlow's ring algorithms for all-reduce and + all-gather. + * `NCCL`: NVIDIA®'s NCCL library. This is now only used for all-reduce on + GPUs; all-reduce on CPU, all-gather and broadcast fallbacks to RING. + """ + AUTO = "AUTO" + RING = "RING" + NCCL = "NCCL" + # TODO(ayushd): add ncclAllGather implementation. + + +CollectiveCommunication = CommunicationImplementation + + +@tf_export("distribute.experimental.CommunicationOptions") +class _OptionsExported(object): + """Options for cross device communications like All-reduce. + + This can be passed to methods like + `tf.distribute.get_replica_context().all_reduce()` to optimize collective + operation performance. Note that these are only hints, which may or may not + change the actual behavior. Some options only apply to certain strategy and + are ignored by others. + + One common optimization is to break gradients all-reduce into multiple packs + so that weight updates can overlap with gradient all-reduce. + + Examples: + + ```python + options = tf.distribute.experimental.CommunicationOptions( + bytes_per_pack=50 * 1024 * 1024, + timeout_seconds=120, + implementation=tf.distribute.experimental.CommunicationImplementation.NCCL + ) + grads = tf.distribute.get_replica_context().all_reduce( + 'sum', grads, options=options) + optimizer.apply_gradients(zip(grads, vars), + experimental_aggregate_gradients=False) + ``` + + """ + + def __new__(cls, *args, **kwargs): + return Options.__new__(Options, *args, **kwargs) + + def __init__(self, + bytes_per_pack=0, + timeout_seconds=None, + implementation=CommunicationImplementation.AUTO): + """Creates a CollectiveHints. + + Args: + bytes_per_pack: a non-negative integer. Breaks collective operations into + packs of certain size. If it's zero, the value is determined + automatically. This only applies to all-reduce with + `MultiWorkerMirroredStrategy` currently. + timeout_seconds: a float or None, timeout in seconds. If not None, the + collective raises `tf.errors.DeadlineExceededError` if it takes longer + than this timeout. Zero disables timeout. This can be useful when + debugging hanging issues. This should only be used for debugging since + it creates a new thread for each collective, i.e. an overhead of + `timeout_seconds * num_collectives_per_second` more threads. This only + works for `tf.distribute.experimental.MultiWorkerMirroredStrategy`. + implementation: a + `tf.distribute.experimental.CommunicationImplementation`. This is a hint + on the preferred communication implementation. Possible values include + `AUTO`, `RING`, and `NCCL`. NCCL is generally more performant for GPU, + but doesn't work for CPU. This only works for + `tf.distribute.experimental.MultiWorkerMirroredStrategy`. + + Raises: + ValueError: When arguments have invalid value. + """ + pass + + +class Options(object): + """Implementation of OptionsInterface.""" + + def __init__(self, + bytes_per_pack=0, + timeout_seconds=None, + implementation=CommunicationImplementation.AUTO): + if bytes_per_pack < 0: + raise ValueError("bytes_per_pack must be non-negative") + if isinstance(implementation, str): + implementation = CommunicationImplementation(implementation.upper()) + if not isinstance(implementation, CommunicationImplementation): + raise ValueError("implementation should be a " + "tf.distribute.experimental.CommunicationImplementation") + self.bytes_per_pack = bytes_per_pack + self.timeout_seconds = timeout_seconds + self.implementation = implementation + + __init__.__doc__ = _OptionsExported.__init__.__doc__ + + def merge(self, options): + """Merges with another options and returns a new one. + + Values specified in the `options` takes precedence if they're not the + default. + + Args: + options: a `tf.distribute.experimental.CollectiveCommunication`. + + Returns: + A new `tf.distribute.experimental.CollectiveCommunication`. + """ + merged = copy.deepcopy(self) + if options is None: + return merged + if options.bytes_per_pack != 0: + merged.bytes_per_pack = options.bytes_per_pack + if options.timeout_seconds is not None: + merged.timeout_seconds = options.timeout_seconds + if options.implementation != CommunicationImplementation.AUTO: + merged.implementation = options.implementation + return merged + + @tf_export("distribute.experimental.CollectiveHints") class Hints(object): """Hints for collective operations like AllReduce. @@ -61,6 +198,12 @@ class Hints(object): """ + @deprecation.deprecated( + None, "use distribute.experimental.CommunicationOptions instead") + def __new__(cls, bytes_per_pack=0, timeout_seconds=None): + return Options( + bytes_per_pack=bytes_per_pack, timeout_seconds=timeout_seconds) + def __init__(self, bytes_per_pack=0, timeout_seconds=None): """Creates a CollectiveHints. @@ -80,7 +223,4 @@ class Hints(object): Raises: ValueError: When arguments have invalid value. """ - if bytes_per_pack < 0: - raise ValueError("bytes_per_pack must be non-negative") - self.bytes_per_pack = bytes_per_pack - self.timeout_seconds = timeout_seconds + pass diff --git a/tensorflow/python/distribute/collective_util_test.py b/tensorflow/python/distribute/collective_util_test.py new file mode 100644 index 00000000000..e75d520979b --- /dev/null +++ b/tensorflow/python/distribute/collective_util_test.py @@ -0,0 +1,41 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for utilities for collectives.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.distribute import collective_util +from tensorflow.python.eager import test + + +class OptionsTest(test.TestCase): + + def testCreateOptionsViaExportedAPI(self): + options = collective_util._OptionsExported() + self.assertIsInstance(options, collective_util.Options) + + def testCreateOptionsViaHints(self): + with self.assertLogs() as cm: + options = collective_util.Hints(50, 1) + self.assertTrue(any("is deprecated" in msg for msg in cm.output)) + self.assertIsInstance(options, collective_util.Options) + self.assertEqual(options.bytes_per_pack, 50) + self.assertEqual(options.timeout_seconds, 1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py index c9d3d7d9a9a..861a2f0490e 100644 --- a/tensorflow/python/distribute/combinations.py +++ b/tensorflow/python/distribute/combinations.py @@ -101,7 +101,7 @@ class ClusterParameters(combinations_lib.ParameterModifier): else: has_chief = kwargs.get("has_chief", False) num_workers = kwargs.get("num_workers", 1) - runner = None + runner = kwargs.get("runner", None) # Always set cluster parameters if they're requested. So that generate() # works when there's no startegy in the combinations. @@ -281,23 +281,24 @@ class NamedDistribution(object): self.use_cloud_tpu = use_cloud_tpu self.has_chief = has_chief self.num_workers = num_workers + self.use_pool_runner = use_pool_runner self.no_xla = no_xla self._runner = None - if _num_total_workers(self.has_chief, self.num_workers) > 1: - cluster_spec = multi_worker_test_base.create_cluster_spec( - has_chief=has_chief, - num_workers=num_workers, - num_ps=0, - has_eval=False) - if use_pool_runner: - # Need to create the strategy in the initializer so that collectives are - # configured before eager context initialization. - self._runner = multi_process_runner.MultiProcessPoolRunner( - cluster_spec, initializer=self._distribution_fn) - @property def runner(self): + if not self._runner: + if (_num_total_workers(self.has_chief, self.num_workers) > 1 and + self.use_pool_runner): + # Need to create the strategy in the initializer so that collectives are + # configured before eager context initialization. + cluster_spec = multi_worker_test_base.create_cluster_spec( + has_chief=self.has_chief, + num_workers=self.num_workers, + num_ps=0, + has_eval=False) + self._runner = multi_process_runner.MultiProcessPoolRunner( + cluster_spec, initializer=self._distribution_fn) return self._runner @property diff --git a/tensorflow/python/distribute/combinations_test.py b/tensorflow/python/distribute/combinations_test.py index 02ddcbef632..e9897a45805 100644 --- a/tensorflow/python/distribute/combinations_test.py +++ b/tensorflow/python/distribute/combinations_test.py @@ -96,6 +96,13 @@ class ClusterCombinationTest(test.TestCase, parameterized.TestCase): # set to the main process. self.assertNotEqual(os.getenv("TF_CONFIG"), "") + def test_runner_creation(self): + cmb = combinations.NamedDistribution( + "Strategy1", lambda: None, has_chief=True, num_workers=2, + use_pool_runner=True) + self.assertIsNone(cmb._runner) + self.assertIsNotNone(cmb.runner) + # unittest.expectedFailure doesn't work with parameterized test methods, so we # have to decorate the class instead. diff --git a/tensorflow/python/distribute/client/BUILD b/tensorflow/python/distribute/coordinator/BUILD similarity index 66% rename from tensorflow/python/distribute/client/BUILD rename to tensorflow/python/distribute/coordinator/BUILD index f6a4fd7f17c..4601f194be2 100644 --- a/tensorflow/python/distribute/client/BUILD +++ b/tensorflow/python/distribute/coordinator/BUILD @@ -8,8 +8,8 @@ package( exports_files(["LICENSE"]) py_library( - name = "client", - srcs = ["client.py"], + name = "cluster_coordinator", + srcs = ["cluster_coordinator.py"], srcs_version = "PY2AND3", deps = [ ":metric_utils", @@ -34,37 +34,28 @@ py_library( ) tf_py_test( - name = "client_test", + name = "cluster_coordinator_test", size = "small", - srcs = ["client_test.py"], + srcs = ["cluster_coordinator_test.py"], python_version = "PY3", - shard_count = 12, - tfrt_enabled = True, - deps = [ - ":client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training_lib", - "//tensorflow/python:util", - "//tensorflow/python/eager:def_function", + shard_count = 50, + tags = [ + "no_oss", # TODO(b/162119374) + "notsan", # TODO(b/171040359): Flaky timeout, even if maximum shards ], -) - -tf_py_test( - name = "parameter_server_client_test", - srcs = ["parameter_server_client_test.py"], - python_version = "PY3", - shard_count = 14, - tags = ["no_oss"], # TODO(b/162119374) deps = [ - ":client", + ":cluster_coordinator", "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:tensor_spec", + "//tensorflow/python:training_lib", "//tensorflow/python:training_server_lib", + "//tensorflow/python:util", "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:multi_worker_test_base", @@ -75,12 +66,16 @@ tf_py_test( ) tf_py_test( - name = "client_mpr_test", - srcs = ["client_mpr_test.py"], + name = "cluster_coordinator_mpr_test", + srcs = ["cluster_coordinator_mpr_test.py"], python_version = "PY3", - shard_count = 2, - tags = ["no_oss"], # TODO(b/162119374) + shard_count = 5, + tags = [ + "no_oss_py38", # TODO(b/171435331) + "notsan", # TODO(b/171406091) + ], deps = [ + ":cluster_coordinator", ":remote_eager_lib", ":utils", "//tensorflow/python:dtypes", @@ -90,13 +85,45 @@ tf_py_test( "//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/distribute:parameter_server_strategy_v2", "//tensorflow/python/distribute:sharded_variable", - "//tensorflow/python/distribute/client", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", ], ) +tf_py_test( + name = "fault_tolerance_test", + srcs = ["fault_tolerance_test.py"], + python_version = "PY3", + shard_count = 9, + tags = [ + "no_oss", # TODO(b/168772720) + "noasan", # Multi-process runner does not work with test sanitizers + "notsan", # Multi-process runner does not work with test sanitizers + ], + deps = [ + ":cluster_coordinator", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:variables", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/distribute:multi_process_runner", + "//tensorflow/python/distribute:multi_worker_test_base", + "//tensorflow/python/distribute:parameter_server_strategy_v2", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/eager:test", + "//tensorflow/python/training:training_lib", + ], +) + py_library( name = "metric_utils", srcs = ["metric_utils.py"], @@ -111,7 +138,7 @@ tf_py_test( srcs = ["metric_utils_test.py"], python_version = "PY3", deps = [ - ":client", + ":cluster_coordinator", ":metric_utils", "//tensorflow/python:training_server_lib", "//tensorflow/python/distribute:multi_worker_test_base", diff --git a/tensorflow/python/distribute/client/client.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py similarity index 67% rename from tensorflow/python/distribute/client/client.py rename to tensorflow/python/distribute/coordinator/cluster_coordinator.py index 6eabbfa219a..651fd80b400 100644 --- a/tensorflow/python/distribute/client/client.py +++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Module for `Client` and relevant cluster-worker related library. +"""Module for `ClusterCoordinator` and relevant cluster-worker related library. This is currently under development and the API is subject to change. """ @@ -29,13 +29,14 @@ import os import re import sys import threading +import time import weakref from six.moves import queue from tensorflow.python.data.ops import iterator_ops from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import parameter_server_strategy_v2 -from tensorflow.python.distribute.client import metric_utils +from tensorflow.python.distribute.coordinator import metric_utils from tensorflow.python.eager import cancellation from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -44,10 +45,10 @@ from tensorflow.python.eager import function as tf_function from tensorflow.python.framework import errors from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export # Maximum time for failed worker to come back is 1 hour _WORKER_MAXIMUM_RECOVERY_SEC = 3600 @@ -56,9 +57,10 @@ _WORKER_MAXIMUM_RECOVERY_SEC = 3600 # When the maximum queue size is reached, further schedule calls will become # blocking until some previously queued closures are executed on workers. # Note that using an "infinite" queue size can take a non-trivial portion of -# memory, and even lead to client OOM. Modify the size to a smaller value for -# client with constrained memory resource (only recommended for advanced users). -# Also used in unit tests to ensure the correctness when the queue is full. +# memory, and even lead to coordinator OOM. Modify the size to a smaller value +# for coordinator with constrained memory resource (only recommended for +# advanced users). Also used in unit tests to ensure the correctness when the +# queue is full. _CLOSURE_QUEUE_MAX_SIZE = 256 * 1024 # RPC error message from PS @@ -99,22 +101,81 @@ class _RemoteValueStatus(enum.Enum): READY = "READY" +@tf_export("distribute.experimental.coordinator.RemoteValue", v1=[]) class RemoteValue(object): - """An asynchronously available value of a remotely executed function. + """An asynchronously available value of a scheduled function. - `RemoteValue` class is used as the return value of `Client.schedule()` where - the underlying concrete value comes at a later time once the function has been - remotely executed. `RemoteValue` can be used as an input to a subsequent - function scheduled with `Client.schedule()`. + This class is used as the return value of + `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` where + the underlying value becomes available at a later time once the function has + been executed. - Note: this class is not thread-safe. + Using `tf.distribute.experimental.coordinator.RemoteValue` as an input to + a subsequent function scheduled with + `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` is + currently not supported. + + Example: + + ```python + strategy = tf.distribute.experimental.ParameterServerStrategy( + cluster_resolver=...) + coordinator = ( + tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)) + + with strategy.scope(): + v1 = tf.Variable(initial_value=0.0) + v2 = tf.Variable(initial_value=1.0) + + @tf.function + def worker_fn(): + v1.assign_add(0.1) + v2.assign_sub(0.2) + return v1.read_value() / v2.read_value() + + result = coordinator.schedule(worker_fn) + # Note that `fetch()` gives the actual result instead of a `tf.Tensor`. + assert result.fetch() == 0.125 + + for _ in range(10): + # `worker_fn` will be run on arbitrary workers that are available. The + # `result` value will be available later. + result = coordinator.schedule(worker_fn) + ``` """ - def __init__(self, closure, type_spec): + def fetch(self): + """Wait for the result of `RemoteValue` to be ready and return the result. + + This makes the value concrete by copying the remote value to local. + + Returns: + The actual output of the `tf.function` associated with this `RemoteValue`, + previously by a + `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call. + This can be a single value, or a structure of values, depending on the + output of the `tf.function`. + + Raises: + tf.errors.CancelledError: If the function that produces this `RemoteValue` + is aborted or cancelled due to failure. + """ + raise NotImplementedError("Must be implemented in subclasses.") + + +class RemoteValueImpl(RemoteValue): + """Implementation of `RemoteValue`.""" + + def __init__(self, closure, type_spec): # pylint: disable=super-init-not-called + """Initializes a `RemoteValueImpl`. + + Args: + closure: The closure from which the `RemoteValue` is created. + type_spec: The type spec for this `RemoteValue` which is used to trace + functions that take this `RemoteValue` as input. + """ self._closure = closure - # The type spec for this `RemoteValue` which is used to trace functions that - # take this `RemoteValue` as input. - self._type_spec = func_graph.convert_structure_to_signature(type_spec) + self._type_spec = type_spec self._value = None self._error = None self._status_available_event = threading.Event() @@ -153,20 +214,7 @@ class RemoteValue(object): self._status_available_event.wait() return self._error - def _set_type_spec(self, type_spec): - self._type_spec = func_graph.convert_structure_to_signature(type_spec) - def fetch(self): - """Wait for the result of RemoteValue to be ready and return the result. - - Returns: - The remote value, as a numpy data type (if scalar) or ndarray. - - Raises: - tf.errors.CancelledError: If the function that produces this `RemoteValue` - is aborted or cancelled due to failure, and the user should handle and - reschedule. - """ self._status_available_event.wait() if self._status is _RemoteValueStatus.ABORTED: raise errors.CancelledError( @@ -176,11 +224,8 @@ class RemoteValue(object): if self._error is not None: raise self._error # pylint: disable=raising-bad-type else: - if isinstance(self._value, - (ops.Tensor, resource_variable_ops.BaseResourceVariable)): - return self._value.numpy() - else: - return self._value + return nest.map_structure( + lambda x: x.numpy() if hasattr(x, "numpy") else x, self._value) class InputError(Exception): @@ -241,8 +286,23 @@ def _maybe_as_type_spec(val): return val +@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[]) class PerWorkerValues(object): - """Holds a list of per worker values.""" + """A container that holds a list of values, one value per worker. + + `tf.distribute.experimental.coordinator.PerWorkerValues` contains a collection + of values, where each of the value is located one worker respectively, and + upon being used as one of the `args` or `kwargs` of + `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the + value specific to a worker will be passed into the function being executed at + that particular worker. + + Currently, the only supported path to create an object of + `tf.distribute.experimental.coordinator.PerWorkerValues` is through calling + `iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned + distributed dataset instance. The mechanism to create a custom + `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet supported. + """ def __init__(self, values): self._values = tuple(values) @@ -262,9 +322,10 @@ def _disallow_remote_value_as_input(structured): def _raise_if_remote_value(x): if isinstance(x, RemoteValue): - raise ValueError("RemoteValue cannot be used as an input to scheduled " - "function. Please file a feature request if you need " - "this feature.") + raise ValueError( + "`tf.distribute.experimental.coordinator.RemoteValue` used " + "as an input to scheduled function is not yet " + "supported.") nest.map_structure(_raise_if_remote_value, structured) @@ -274,8 +335,8 @@ class Closure(object): def __init__(self, function, cancellation_mgr, args=None, kwargs=None): if not callable(function): - raise ValueError("Function passed to `Client.schedule` must be a " - "callable object.") + raise ValueError("Function passed to `ClusterCoordinator.schedule` must " + "be a callable object.") self._args = args or () self._kwargs = kwargs or {} @@ -287,9 +348,9 @@ class Closure(object): replica_kwargs = _select_worker_slice(0, self._kwargs) # Note: no need to handle function registration failure since this kind of - # failure will not raise exceptions as designed in the runtime. The client - # has to rely on subsequent operations that raise to catch function - # registration failure. + # failure will not raise exceptions as designed in the runtime. The + # coordinator has to rely on subsequent operations that raise to catch + # function registration failure. # Record the function tracing overhead. Note that we pass in the tracing # count of the def_function.Function as a state tracker, so that metrics @@ -297,36 +358,32 @@ class Closure(object): # function cache lookups). with metric_utils.monitored_timer( "function_tracing", state_tracker=function._get_tracing_count): # pylint: disable=protected-access - concrete_function = function.get_concrete_function( + self._concrete_function = function.get_concrete_function( *nest.map_structure(_maybe_as_type_spec, replica_args), **nest.map_structure(_maybe_as_type_spec, replica_kwargs)) - self._function = cancellation_mgr.get_cancelable_function( - concrete_function) - self._output_remote_values = nest.map_structure( - lambda x: RemoteValue(self, x), concrete_function.structured_outputs) elif isinstance(function, tf_function.ConcreteFunction): - self._function = cancellation_mgr.get_cancelable_function(function) - self._output_remote_values = nest.map_structure( - lambda x: RemoteValue(self, x), function.structured_outputs) + self._concrete_function = function + + if hasattr(self, "_concrete_function"): + # If we have a concrete function, we get to retrieve the output type spec + # via the structured_output. + output_type_spec = func_graph.convert_structure_to_signature( + self._concrete_function.structured_outputs) + self._function = cancellation_mgr.get_cancelable_function( + self._concrete_function) else: - # Regular python functions. + # Otherwise (i.e. what is passed in is a regular python function), we have + # no such information. + output_type_spec = None self._function = function - # TODO(yuefengz): maybe we should trace python functions if their inputs - # are Python primitives, tensors and composite tensors. - self._output_remote_values = RemoteValue(self, None) - def _fetch_output_remote_values(self): - """Temporary method used to sync the scheduler.""" - # It will do nothing if there is no return value. - nest.map_structure(lambda x: x.fetch(), self._output_remote_values) # pylint: disable=protected-access + self.output_remote_value = RemoteValueImpl(self, output_type_spec) - def _set_output_remote_values_cancelled(self): - nest.map_structure( - lambda x: x._set_error( # pylint: disable=protected-access,g-long-lambda - errors.CancelledError( - None, None, "The corresponding function is " - "cancelled. Please reschedule the function.")), - self._output_remote_values) # pylint: disable=protected-access + def mark_cancelled(self): + self.output_remote_value._set_error( # pylint: disable=protected-access + errors.CancelledError( + None, None, "The corresponding function is " + "cancelled. Please reschedule the function.")) def execute_on(self, worker): """Executes the closure on the given worker. @@ -343,8 +400,7 @@ class Closure(object): if e: if not isinstance(e, InputError): e = InputError(e) - for remote_value in nest.flatten(self._output_remote_values): - remote_value._set_error(e) # pylint: disable=protected-access + self.output_remote_value._set_error(e) # pylint: disable=protected-access return with ops.device(worker.device_name): @@ -353,9 +409,7 @@ class Closure(object): output_value = self._function( *nest.map_structure(_maybe_get_remote_value, replica_args), **nest.map_structure(_maybe_get_remote_value, replica_kwargs)) - for remote_value, value in zip( - nest.flatten(self._output_remote_values), nest.flatten(output_value)): - remote_value._set_value(value) # pylint: disable=protected-access + self.output_remote_value._set_value(output_value) # pylint: disable=protected-access class _CoordinatedClosureQueue(object): @@ -394,7 +448,7 @@ class _CoordinatedClosureQueue(object): if _CLOSURE_QUEUE_MAX_SIZE <= 0: logging.warning( - "In a `Client`, creating an infinite closure queue can " + "In a `ClusterCoordinator`, creating an infinite closure queue can " "consume a significant amount of memory and even lead to OOM.") self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE) self._error = None @@ -424,16 +478,16 @@ class _CoordinatedClosureQueue(object): try: closure = self._queue.get(block=False) self._queue_free_slot_condition.notify() - closure._set_output_remote_values_cancelled() # pylint: disable=protected-access + closure.mark_cancelled() except queue.Empty: break # The cancellation manager cannot be reused once cancelled. After all # closures (queued or inflight) are cleaned up, recreate the cancellation # manager with clean state. - # Note on thread-safety: this is triggered when one of theses client APIs - # are called: `schedule`, `wait`, and `done`. At the same time, no new - # closures can be constructed (which reads the _cancellation_mgr to get - # cancellable functions). + # Note on thread-safety: this is triggered when one of theses + # ClusterCoordinator APIs are called: `schedule`, `wait`, and `done`. At the + # same time, no new closures can be constructed (which reads the + # _cancellation_mgr to get cancellable functions). self._cancellation_mgr = cancellation.CancellationManager() def _raise_if_error(self): @@ -496,7 +550,7 @@ class _CoordinatedClosureQueue(object): if self._inflight_closure_count < 1: raise AssertionError("There is no inflight closures to put_back.") if self._error: - closure._set_output_remote_values_cancelled() # pylint: disable=protected-access + closure.mark_cancelled() else: self._queue_free_slot_condition.wait_for(lambda: not self._queue.full()) self._queue.put(closure, block=False) @@ -690,7 +744,7 @@ class Worker(object): closure.execute_on(self) # TODO(yuefengz): we don't have to materialize results every step. with metric_utils.monitored_timer("remote_value_fetch"): - closure._fetch_output_remote_values() # pylint: disable=protected-access + closure.output_remote_value.fetch() self._cluster._closure_queue.mark_finished() # pylint: disable=protected-access except Exception as e: # pylint: disable=broad-except # Avoid logging the derived cancellation error @@ -698,12 +752,28 @@ class Worker(object): logging.error( "/job:worker/task:%d encountered the following error when " "processing closure: %r:%s", self.worker_index, e, e) - nest.map_structure( - lambda x: x._set_error(e), # pylint: disable=protected-access - closure._output_remote_values) # pylint: disable=protected-access + closure.output_remote_value._set_error(e) # pylint: disable=protected-access self._cluster._closure_queue.mark_failed(e) # pylint: disable=protected-access + def _maybe_delay(self): + """Delay if corresponding env vars are set.""" + # If the following two env vars variables are set. Scheduling for workers + # will start in a staggered manner. Worker i will wait for + # `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, not exceeding + # `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX`. + delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0")) + delay_cap = int( + os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0")) + if delay_cap: + delay_secs = min(delay_secs * self.worker_index, delay_cap) + if delay_secs > 0: + logging.info("Worker %d sleeping for %d seconds before running function", + self.worker_index, delay_secs) + time.sleep(delay_secs) + def _process_queue(self): + """Function running in a thread to process closure queues.""" + self._maybe_delay() while True: closure = self._cluster._closure_queue.get() # pylint: disable=protected-access self._process_closure(closure) @@ -730,7 +800,7 @@ class Worker(object): self._cluster._closure_queue._cancellation_mgr, # pylint: disable=protected-access args=args, kwargs=kwargs) - resource_remote_value = closure._output_remote_values # pylint: disable=protected-access + resource_remote_value = closure.output_remote_value self._register_resource(resource_remote_value) # The following is a short-term solution to lazily create resources in @@ -742,8 +812,8 @@ class Worker(object): def _register_resource(self, resource_remote_value): if not isinstance(resource_remote_value, RemoteValue): - raise ValueError( - "Resource being registered is not of type `RemoteValue`.") + raise ValueError("Resource being registered is not of type " + "`tf.distribute.experimental.coordinator.RemoteValue`.") self._resource_remote_value_refs.append(weakref.ref(resource_remote_value)) @@ -753,7 +823,7 @@ class Cluster(object): We assume all function errors are fatal and based on this assumption our error reporting logic is: 1) Both `schedule` and `join` can raise a non-retryable error which is the - first error seen by the client from any previously scheduled functions. + first error seen by the coordinator from any previously scheduled functions. 2) When an error is raised, there is no guarantee on how many previously scheduled functions have been executed; functions that have not been executed will be thrown away and marked as cancelled. @@ -775,17 +845,17 @@ class Cluster(object): # Ignore PS failures reported by workers due to transient connection errors. # Transient connectivity issues between workers and PS are relayed by the - # workers to the client, leading the client to believe that there are PS - # failures. The difference between transient vs. permanent PS failure is the - # number of reports from the workers. When this env var is set to a positive - # integer K, the client ignores up to K reports of a failed PS task. I.e., - # only when there are more than K trials of executing closures fail due to - # errors from the same PS instance do we consider the PS instance encounters - # a failure. + # workers to the coordinator, leading the coordinator to believe that there + # are PS failures. The difference between transient vs. permanent PS failure + # is the number of reports from the workers. When this env var is set to a + # positive integer K, the coordinator ignores up to K reports of a failed PS + # task, i.e., only when there are more than K trials of executing closures + # fail due to errors from the same PS instance do we consider the PS + # instance encounters a failure. # TODO(b/164279603): Remove this workaround when the underlying connectivity # issue in gRPC server is resolved. - self._transient_ps_failures_threshold = int(os.environ.get( - "TF_CLIENT_IGNORE_TRANSIENT_PS_FAILURES", 3)) + self._transient_ps_failures_threshold = int( + os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3)) self._potential_ps_failures_lock = threading.Lock() self._potential_ps_failures_count = [0] * self._num_ps @@ -825,7 +895,7 @@ class Cluster(object): kwargs: Keyword arguments for `fn`. Returns: - A structure of `RemoteValue` object. + A `RemoteValue` object. """ closure = Closure( function, @@ -833,7 +903,7 @@ class Cluster(object): args=args, kwargs=kwargs) self._closure_queue.put(closure) - return closure._output_remote_values # pylint: disable=protected-access + return closure.output_remote_value def join(self): """Blocks until all scheduled functions are executed.""" @@ -844,73 +914,114 @@ class Cluster(object): return self._closure_queue.done() -class ParameterServerFailureError(Exception): - """An error representing at least one parameter server is interrupted.""" - pass +@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[]) +class ClusterCoordinator(object): + """An object to schedule and coordinate remote function execution. + This class is used to create fault-tolerant resources and dispatch functions + to remote TensorFlow servers. -class Client(object): - """An object to schedule and orchestrate remote function execution. + Currently, this class is not supported to be used in a standalone manner. It + should be used in conjunction with a `tf.distribute` strategy that is designed + to work with it. The `ClusterCoordinator` class currently only works + `tf.distribute.experimental.ParameterServerStrategy`. - A `Client` object represents a program used to create dataset, schedule - functions to be executed, and fetch the results of the functions. + __The `schedule`/`join` APIs__ - Currently, `Client` is not supported to be used in a standalone manner. - It should be used in conjunction with `ParameterServerStrategyV2`. + The most important APIs provided by this class is the `schedule`/`join` pair. + The `schedule` API is non-blocking in that it queues a `tf.function` and + returns a `RemoteValue` immediately. The queued functions will be dispatched + to remote workers in background threads and their `RemoteValue`s will be + filled asynchronously. Since `schedule` doesn’t require worker assignment, the + `tf.function` passed in can be executed on any available worker. If the worker + it is executed on becomes unavailable before its completion, it will be + migrated to another worker. Because of this fact and function execution is not + atomic, a function may be executed more than once. + + __Handling Task Failure__ + + This class when used with + `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in + fault tolerance for worker failures. That is, when some workers are not + available for any reason to be reached from the coordinator, the training + progress continues to be made with the remaining workers. Upon recovery of a + failed worker, it will be added for function execution after datasets created + by `create_per_worker_dataset` are re-built on it. + + When a parameter server the coordinator fails, a + `tf.errors.UnavailableError` is raised by `schedule`, `join` or `done`. In + this case, in addition to bringing back the failed parameter server, users + should restart the coordinator to so that it reconnects to the parameter + server, re-creates the variables and loads checkpoints. If the coordinator + fails, users need to bring it back as well. The program will automatically + connect to the parameter servers and workers, and continue the progress from a + checkpoint. + + It is thus essential that in user's program, a checkpoint file is periodically + saved, and restored at the start of the program. If an + `tf.keras.optimizers.Optimizer` is checkpointed, after restoring from a + checkpoiont, its `iterations` property roughly indicates the number of steps + that have been made. This can be used to decide how many epochs and steps are + needed before the training completion. + + See `tf.distribute.experimental.ParameterServerStrategy` docstring for an + example usage of this API. This is currently under development, and the API as well as implementation - is subject to changes. + are subject to changes. """ def __init__(self, strategy): - """Initialization of a `Client` instance. - - This connects the client to remote workers and parameter servers, through - a `tf.config.experimental_connect_to_cluster` call. + """Initialization of a `ClusterCoordinator` instance. Args: - strategy: a `tf.distribute.Strategy` object. Currently, only - `ParameterServerStrategyV2` is supported. + strategy: a supported `tf.distribute.Strategy` object. Currently, only + `tf.distribute.experimental.ParameterServerStrategy` is supported. Raises: ValueError: if the strategy being used is not supported. """ if not isinstance(strategy, parameter_server_strategy_v2.ParameterServerStrategyV2): - raise ValueError("Only `ParameterServerStrategyV2` is supported in " - "`Client` currently.") + raise ValueError( + "Only `tf.distribute.experimental.ParameterServerStrategy` " + "is supported to work with " + "`tf.distribute.experimental.coordinator.ClusterCoordinator` " + "currently.") self._strategy = strategy self.cluster = Cluster(strategy) @property def strategy(self): + """Returns the `Strategy` associated with the `ClusterCoordinator`.""" return self._strategy def schedule(self, fn, args=None, kwargs=None): - """Schedules `fn` to be dispatched to a worker for execution asynchronously. + """Schedules `fn` to be dispatched to a worker for asynchronous execution. - When calling `schedule` with a function `fn`, `fn` will be executed on a - remote worker at some later time. The process is asynchronous, meaning - `schedule` returns immediately, possibly without having the result ready - yet. `schedule` returns a structure of `RemoteValue` object, which wraps the - output of the function. Call `fetch()` on `RemoteValue` to wait for the - function execution to finish and retrieve its output from the remote worker. + This method is non-blocking in that it queues the `fn` which will be + executed later and returns a + `tf.distribute.experimental.coordinator.RemoteValue` object immediately. + `fetch` can be called on the it to wait for the function execution to finish + and retrieve its output from a remote worker. On the other hand, call + `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for + all scheduled functions to finish. `schedule` guarantees that `fn` will be executed on a worker at least once; it could be more than once if its corresponding worker fails in the middle of its execution. Note that since worker can fail at any point when executing the function, it is possible that the function is partially - executed, but `Client` guarantees that in those events, the function will - eventually be fully executed, possibly on a different worker that is - available. + executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator` + guarantees that in those events, the function will eventually be executed on + any worker that is available. - If any previously scheduled function raises an error, `schedule` will fail - by raising any one of those errors, and clear the errors collected so far. - There are two implications when this happens: 1) user should call `schedule` - with `fn` again to re-schedule, and 2) some of the previously scheduled - functions may have not been executed. User can call `fetch` on the returned - `RemoteValue` to inspect if they have executed, failed, or cancelled, and - reschedule the corresponding function if needed. + If any previously scheduled function raises an error, `schedule` will raise + any one of those errors, and clear the errors collected so far. What happens + here, some of the previously scheduled functions may have not been executed. + User can call `fetch` on the returned + `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have + executed, failed, or cancelled, and reschedule the corresponding function if + needed. When `schedule` raises, it guarantees that there is no function that is still being executed. @@ -919,12 +1030,13 @@ class Client(object): execution, or priority of the workers. `args` and `kwargs` are the arguments passed into `fn`, when `fn` is - executed on a worker. They can be `PerWorkerValues`, which is a collection - of values, each of which represents a component specific to a worker; in - this case, the argument will be substituted with the corresponding component - on the target worker. Arguments that are not `PerWorkerValues` will be - passed into `fn` as-is. Currently, `RemoteValue` is not supported to be - input `args` or `kwargs`. + executed on a worker. They can be + `tf.distribute.experimental.coordinator.PerWorkerValues` and in this case, + the argument will be substituted with the corresponding component on the + target worker. Arguments that are not + `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed into + `fn` as-is. Currently, `tf.distribute.experimental.coordinator.RemoteValue` + is not supported to be input `args` or `kwargs`. Args: fn: A `tf.function`; the function to be dispatched to a worker for @@ -933,16 +1045,17 @@ class Client(object): kwargs: Keyword arguments for `fn`. Returns: - A structure of `RemoteValue` object. + A `tf.distribute.experimental.coordinator.RemoteValue` object that + represents the output of the function scheduled. Raises: - Exception: one of the exceptions caught by the client by any previously - scheduled function since the last time an error was thrown or since - the beginning of the program. + Exception: one of the exceptions caught by the coordinator from any + previously scheduled function, since the last time an error was thrown + or since the beginning of the program. """ # Slot variables are usually created during function tracing time; thus # `schedule` needs to be called within the `strategy.scope()`. - with self.strategy.scope(), _translate_parameter_server_failure(): + with self.strategy.scope(): return self.cluster.schedule(fn, args=args, kwargs=kwargs) def join(self): @@ -951,21 +1064,20 @@ class Client(object): If any previously scheduled function raises an error, `join` will fail by raising any one of those errors, and clear the errors collected so far. If this happens, some of the previously scheduled functions may have not been - executed. Users can call `fetch` on the returned `RemoteValue` to inspect if - they have executed, failed, or cancelled. If some that have been cancelled - need to be rescheduled, users should call `schedule` with the function - again. + executed. Users can call `fetch` on the returned + `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have + executed, failed, or cancelled. If some that have been cancelled need to be + rescheduled, users should call `schedule` with the function again. When `join` returns or raises, it guarantees that there is no function that is still being executed. Raises: - Exception: one of the exceptions caught by the client by any previously - scheduled function since the last time an error was thrown or since - the beginning of the program. + Exception: one of the exceptions caught by the coordinator by any + previously scheduled function since the last time an error was thrown or + since the beginning of the program. """ - with _translate_parameter_server_failure(): - self.cluster.join() + self.cluster.join() def done(self): """Returns whether all the scheduled functions have finished execution. @@ -975,29 +1087,64 @@ class Client(object): When `done` returns True or raises, it guarantees that there is no function that is still being executed. + + Returns: + Whether all the scheduled functions have finished execution. + Raises: + Exception: one of the exceptions caught by the coordinator by any + previously scheduled function since the last time an error was thrown or + since the beginning of the program. """ return self.cluster.done() def create_per_worker_dataset(self, dataset_fn): """Create dataset on workers by calling `dataset_fn` on worker devices. - This creates the given dataset generated by dataset_fn on the workers + This creates the given dataset generated by dataset_fn on workers and returns an object that represents the collection of those individual - datasets. Calling `iter` on such collection of dataset returns a - `PerWorkerValues`, which is a collection of iterators, where the iterators - have been placed on respective workers. + datasets. Calling `iter` on such collection of datasets returns a + `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a + collection of iterators, where the iterators have been placed on respective + workers. - Calling `next` on this `PerWorkerValues` of iterators is currently - unsupported; it is meant to be passed as an argument into `Client.schedule`. - When the scheduled function is picked up and being executed by a worker, the + Calling `next` on a `PerWorkerValues` of iterator is unsupported. The + iterator is meant to be passed as an argument into + `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When + the scheduled function is about to be executed by a worker, the function will receive the individual iterator that corresponds to the - worker, and now `next` can be called on iterator to get the next (batch or - example) of data. + worker. The `next` method can be called on an iterator inside a + scheduled function when the iterator is an input of the function. - Dataset shuffling and repeating are usually needed in `dataset_fn`; however, - sharding is not recommended: some worker may not be available and those - examples may be skipped and not covered by other workers, if the dataset is - sharded. + Currently the `schedule` method assumes workers are all the same and thus + assumes the datasets on different workers are the same, except they may be + shuffled differently if they contain a `dataset.shuffle` operation and a + random seed is not set. Because of this, we also recommend the datasets to + be repeated indefinitely and schedule a finite number of steps instead of + relying on the `OutOfRangeError` from a dataset. + + + Example: + + ```python + strategy = tf.distribute.experimental.ParameterServerStrategy( + cluster_resolver=...) + coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( + strategy=strategy) + + @tf.function + def worker_fn(iterator): + return next(iterator) + + def per_worker_dataset_fn(): + return strategy.distribute_datasets_from_function( + lambda x: tf.data.from_tensor_slices([3] * 3) + + per_worker_dataset = coordinator.create_per_worker_dataset( + per_worker_dataset_fn) + per_worker_iter = iter(per_worker_dataset) + remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,)) + assert remote_value.fetch() == 3 + ``` Args: dataset_fn: The dataset function that returns a dataset. This is to be @@ -1006,7 +1153,8 @@ class Client(object): Returns: An object that represents the collection of those individual datasets. `iter` is expected to be called on this object that returns - a `PerWorkerValues` of the iterators (that are on the workers). + a `tf.distribute.experimental.coordinator.PerWorkerValues` of the + iterators (that are on the workers). """ input_workers = input_lib.InputWorkers([ (w.device_name, [w.device_name]) for w in self.cluster.workers @@ -1017,7 +1165,8 @@ class Client(object): def _create_per_worker_resources(self, fn, args=None, kwargs=None): """Synchronously create resources on the workers. - The resources are represented by `RemoteValue`s. + The resources are represented by + `tf.distribute.experimental.coordinator.RemoteValue`s. Args: fn: The function to be dispatched to all workers for execution @@ -1026,7 +1175,9 @@ class Client(object): kwargs: Keyword arguments for `fn`. Returns: - A `PerWorkerValues` object, which wraps a tuple of `RemoteValue` objects. + A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which + wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue` + objects. """ results = [] for w in self.cluster.workers: @@ -1034,21 +1185,52 @@ class Client(object): return PerWorkerValues(tuple(results)) def fetch(self, val): - """Blocking call to fetch results from `RemoteValue`s. + """Blocking call to fetch results from the remote values. - This returns the execution result of `RemoteValue`s; if not ready, - waiting for it while blocking the caller. + This is a wrapper around + `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a + `RemoteValue` structure; it returns the execution results of + `RemoteValue`s. If not ready, wait for them while blocking the caller. + + Example: + ```python + strategy = ... + coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( + strategy) + + def dataset_fn(): + return tf.data.Dataset.from_tensor_slices([1, 1, 1]) + + with strategy.scope(): + v = tf.Variable(initial_value=0) + + @tf.function + def worker_fn(iterator): + def replica_fn(x): + v.assign_add(x) + return v.read_value() + return strategy.run(replica_fn, args=(next(iterator),)) + + distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn) + distributed_iterator = iter(distributed_dataset) + result = coordinator.schedule(worker_fn, args=(distributed_iterator,)) + assert coordinator.fetch(result) == 1 + ``` Args: val: The value to fetch the results from. If this is structure of - `RemoteValue`, `fetch()` will be called on the individual `RemoteValue` - to get the result. + `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be + called on the individual + `tf.distribute.experimental.coordinator.RemoteValue` to get the result. Returns: - If `val` is a `RemoteValue` or a structure of `RemoteValue`s, returns - the fetched `RemoteValue` value immediately if it's available, or blocks - the call until it's available, and returns the fetched `RemoteValue` - values with the same structure. If `val` is other types, return (`val`,). + If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a + structure of `tf.distribute.experimental.coordinator.RemoteValue`s, + return the fetched `tf.distribute.experimental.coordinator.RemoteValue` + values immediately if they are available, or block the call until they are + available, and return the fetched + `tf.distribute.experimental.coordinator.RemoteValue` values with the same + structure. If `val` is other types, return it as-is. """ def _maybe_fetch(val): @@ -1058,31 +1240,15 @@ class Client(object): return val # TODO(yuefengz): we should fetch values in a batch. - result = nest.map_structure(_maybe_fetch, val) - if not isinstance(result, tuple): - return (result,) - return result - - -# pylint: disable=missing-function-docstring -@contextlib.contextmanager -def _translate_parameter_server_failure(): - try: - yield - except Exception as e: # pylint: disable=broad-except - if _is_ps_failure(e): - raise ParameterServerFailureError(e) - else: - raise + return nest.map_structure(_maybe_fetch, val) # pylint: disable=missing-function-docstring @contextlib.contextmanager def handle_parameter_server_failure(): try: - with _translate_parameter_server_failure(): - yield - except ParameterServerFailureError as e: # pylint: disable=broad-except + yield + except errors.UnavailableError as e: # pylint: disable=broad-except restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE", None) if restart_exit_code is not None: @@ -1094,13 +1260,14 @@ def handle_parameter_server_failure(): class _PerWorkerDistributedDataset(object): """Represents worker-distributed datasets created from dataset function.""" - def __init__(self, dataset_fn, input_workers, client): + def __init__(self, dataset_fn, input_workers, coordinator): """Makes an iterable from datasets created by the given function. Args: dataset_fn: A function that returns a `Dataset`. input_workers: an `InputWorkers` object. - client: a `Client` object, used to create dataset resources. + coordinator: a `ClusterCoordinator` object, used to create dataset + resources. """ def disallow_variable_creation(next_creator, **kwargs): raise ValueError("Creating variables in `dataset_fn` is not allowed.") @@ -1113,7 +1280,7 @@ class _PerWorkerDistributedDataset(object): dataset_fn = def_function.function(dataset_fn).get_concrete_function() self._dataset_fn = dataset_fn self._input_workers = input_workers - self._client = client + self._coordinator = coordinator self._element_spec = None def __iter__(self): @@ -1131,13 +1298,13 @@ class _PerWorkerDistributedDataset(object): # If _PerWorkerDistributedDataset.__iter__ is called multiple # times, for the same object it should only create and register resource # once. Using object id to distinguish different iterator resources. - per_worker_iterator = self._client._create_per_worker_resources( + per_worker_iterator = self._coordinator._create_per_worker_resources( _create_per_worker_iterator) # Setting type_spec of each RemoteValue so that functions taking these # RemoteValues as inputs can be traced. for iterator_remote_value in per_worker_iterator._values: - iterator_remote_value._set_type_spec( + iterator_remote_value._type_spec = ( # pylint: disable=protected-access iterator_ops.IteratorSpec( self._dataset_fn.structured_outputs.element_spec)) return _PerWorkerDistributedIterator(per_worker_iterator._values) @@ -1150,7 +1317,7 @@ class _PerWorkerDistributedDataset(object): class _PerWorkerDistributedIterator(PerWorkerValues): - """Distributed iterator for `Client`.""" + """Distributed iterator for `ClusterCoordinator`.""" def __next__(self): return self.get_next() @@ -1169,10 +1336,8 @@ def _extract_failed_ps_instances(err_msg): def _is_ps_failure(error): """Whether the error is considered a parameter server failure.""" - if (_RPC_ERROR_FROM_PS in str(error) or - (isinstance(error, errors.InvalidArgumentError) and - "/job:ps" in str(error))): - return True + return (isinstance(error, errors.UnavailableError) and + _RPC_ERROR_FROM_PS in str(error)) def _is_worker_failure(error): @@ -1192,21 +1357,14 @@ def _is_worker_failure(error): # failure. In that case, gRPC allows channel (which is different from a # connection) to be reused for a replaced server listening to same address. if isinstance(error, errors.InvalidArgumentError): - if ("Unable to find a context_id" in str(error) or - "unknown device" in str(error) or + if ("unknown device" in str(error) or "Unable to find the relevant tensor remote_handle" in str(error)): # TODO(b/159961667): Fix "Unable to find the relevant tensor # remote_handle" part. return True - # TODO(b/162541228): The following 3 types of errors are very rare and only + # TODO(b/162541228): The following 2 types of errors are very rare and only # observed in large-scale testing. The types of errors should be reduced. - # This error could show up when copying function inputs from remote tasks. - if isinstance(error, errors.InternalError): - if ("Failed copying input tensor" in str(error) or - "Unable to find a context_id" in str(error)): - return True - # This could happen when the function registration fails. In the observed # cases this only happens to the dataset related functions. if isinstance(error, errors.NotFoundError): diff --git a/tensorflow/python/distribute/client/client_mpr_test.py b/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py similarity index 52% rename from tensorflow/python/distribute/client/client_mpr_test.py rename to tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py index 802b23e87ec..8b3e95f1fea 100644 --- a/tensorflow/python/distribute/client/client_mpr_test.py +++ b/tensorflow/python/distribute/coordinator/cluster_coordinator_mpr_test.py @@ -13,31 +13,112 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Multi-process runner tests for `Client` with `ParameterServerStrategyV2`.""" +"""Multi-process runner tests for `ClusterCoordinator` with PSv2.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function - +import os import time - from tensorflow.python.compat import v2_compat from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import parameter_server_strategy_v2 -from tensorflow.python.distribute.client import client as client_lib -from tensorflow.python.distribute.client import utils from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver +from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib +from tensorflow.python.distribute.coordinator import utils from tensorflow.python.eager import def_function from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging -class ClientMprTest(test.TestCase): +class ClusterCoordinatorMprTest(test.TestCase): + + # TODO(b/168772720): Merge or remove the following task failure tests once + # MultiProcessCluster is made available in OSS. + def testStrategyRun_withWorkerFailures(self): + self._testStrategyRun("worker") + + def testStrategyRun_withPsFailures(self): + self._testStrategyRun("ps") + + def testStrategyRun_withoutFailures(self): + self._testStrategyRun(None) + + def _testStrategyRun(self, failure_task_type): + + def fn(functions_scheduled_event): + # TODO(b/170664373): This is needed for TF2 parameter server training in + # OSS. Remove this when resolved. + os.environ["GRPC_FAIL_FAST"] = "use_caller" + + cluster_resolver = TFConfigClusterResolver() + if cluster_resolver.task_type != "chief": + utils.start_server(cluster_resolver, "grpc") + strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( + cluster_resolver) + ps_client = coordinator_lib.ClusterCoordinator(strategy) + + with strategy.scope(): + v = variables.Variable(initial_value=1) + + @def_function.function + def worker_fn(input_tensor): + + def replica_fn(input_tensor): + return input_tensor + v + + run_result = strategy.run(replica_fn, args=(input_tensor,)) + check_ops.assert_equal_v2(run_result, 4) + return run_result + + for i in range(5000): + if i % 500 == 0: + logging.info("Scheduling function-{}...".format(i)) + result = ps_client.schedule(worker_fn, args=(constant_op.constant(3),)) + functions_scheduled_event.set() + logging.info("Joining...") + ps_client.join() + logging.info("Finished joining.") + if result.fetch() != 4: + raise AssertionError("Unexpected RemoteValue result: {}".format( + result.fetch())) + logging.info("testStrategyRun succeeded") + + manager = multi_process_runner.manager() + functions_scheduled_event = manager.Event() + mpr = multi_process_runner.MultiProcessRunner( + fn, + multi_worker_test_base.create_cluster_spec( + has_chief=True, num_workers=1, num_ps=1, has_eval=False), + args=(functions_scheduled_event,), + rpc_layer="grpc", + return_output=True) + mpr.start() + + if failure_task_type is not None: + functions_scheduled_event.wait() + logging.info("Before interrupting {}-0.".format(failure_task_type)) + mpr.terminate(failure_task_type, 0) + + if failure_task_type == "ps": + with self.assertRaises(errors.UnavailableError): + mpr.join() + return + + time.sleep(10) + logging.info("Before restarting {}-0.".format(failure_task_type)) + mpr.start_single_process(task_type="worker", task_id=0) + + self.assertTrue( + any(["testStrategyRun succeeded" in msg for msg in mpr.join().stdout])) def testScheduleTranslatePSFailureError(self): self._test_translate_ps_failure_error(test_schedule=True) @@ -50,12 +131,16 @@ class ClientMprTest(test.TestCase): test_join=False): def fn(functions_scheduled_event, test_finished_event): + # TODO(b/170664373): This is needed for TF2 parameter server training in + # OSS. Remove this when resolved. + os.environ["GRPC_FAIL_FAST"] = "use_caller" + cluster_resolver = TFConfigClusterResolver() if cluster_resolver.task_type != "chief": utils.start_server(cluster_resolver, "grpc") strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( cluster_resolver) - ps_client = client_lib.Client(strategy) + ps_coordinator = coordinator_lib.ClusterCoordinator(strategy) with strategy.scope(): v = variables.Variable(initial_value=0, dtype=dtypes.int32) @@ -67,21 +152,20 @@ class ClientMprTest(test.TestCase): v.assign_add(1) # Keep the two workers occupied. - ps_client.schedule(worker_fn) - ps_client.schedule(worker_fn) + ps_coordinator.schedule(worker_fn) + ps_coordinator.schedule(worker_fn) # Now the main process can terminate. functions_scheduled_event.set() - # Verified that join and schedule indeed raise - # ParameterServerFailureError. + # Verified that join and schedule indeed raise UnavailableError. try: if test_join: - ps_client.join() + ps_coordinator.join() if test_schedule: - while ps_client.cluster._closure_queue._error is None: + while ps_coordinator.cluster._closure_queue._error is None: time.sleep(1) - ps_client.schedule(worker_fn) - except client_lib.ParameterServerFailureError: + ps_coordinator.schedule(worker_fn) + except errors.UnavailableError: # The following verifies that after PS fails, continue executing # functions on workers should fail and indicate it's PS failure. for worker_id in range(3): @@ -91,7 +175,7 @@ class ClientMprTest(test.TestCase): # failure. worker_fn() except Exception as e: # pylint: disable=broad-except - if client_lib._is_ps_failure(e): + if coordinator_lib._is_ps_failure(e): if worker_id < 2: continue logging.info("_test_translate_ps_failure_error ends properly.") @@ -101,7 +185,7 @@ class ClientMprTest(test.TestCase): raise RuntimeError("Executing a function after PS fails, should " "result in a PS failure.") - raise RuntimeError("ParameterServerFailureError supposed to be raised.") + raise RuntimeError("UnavailableError supposed to be raised.") manager = multi_process_runner.manager() functions_scheduled_event = manager.Event() diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py b/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py new file mode 100644 index 00000000000..a8ab4300713 --- /dev/null +++ b/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py @@ -0,0 +1,928 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for coordinator.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import functools +import os +import platform +import sys +import threading +import time + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import multi_worker_test_base +from tensorflow.python.distribute import parameter_server_strategy_v2 +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib +from tensorflow.python.eager import cancellation +from tensorflow.python.eager import def_function +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import coordinator +from tensorflow.python.training.server_lib import ClusterSpec + + +class CoordinatedClosureQueueTest(test.TestCase): + + def testBasic(self): + queue = coordinator_lib._CoordinatedClosureQueue() + closure1 = self._create_closure(queue._cancellation_mgr) + queue.put(closure1) + self.assertIs(closure1, queue.get()) + self.assertFalse(queue.done()) + queue.put_back(closure1) + self.assertEqual(closure1, queue.get()) + queue.mark_finished() + self.assertTrue(queue.done()) + queue.wait() + + def testProcessAtLeaseOnce(self): + closure_queue = coordinator_lib._CoordinatedClosureQueue() + labels = ['A', 'B', 'C', 'D', 'E'] + processed_count = collections.defaultdict(int) + + coord = coordinator.Coordinator(clean_stop_exception_types=[]) + + def process_queue(): + with coord.stop_on_exception(): + has_been_put_back = False + while True: + closure = closure_queue.get(timeout=30) + if closure is None: + break + if not has_been_put_back: + has_been_put_back = True + closure_queue.put_back(closure) + continue + closure._function() + closure_queue.mark_finished() + + def get_func(label): + + def func(): + time.sleep(3) + processed_count[label] += 1 + + return func + + cm = cancellation.CancellationManager() + for label in labels: + closure_queue.put(coordinator_lib.Closure(get_func(label), cm)) + t1 = threading.Thread(target=process_queue, daemon=True) + t1.start() + t2 = threading.Thread(target=process_queue, daemon=True) + t2.start() + + # Make sure multiple wait() calls are fine. + closure_queue.wait() + closure_queue.wait() + closure_queue.wait() + closure_queue.wait() + + self.assertEqual(processed_count, collections.Counter(labels)) + + coord.join([t1, t2]) + + def testNotifyBeforeWait(self): + closure_queue = coordinator_lib._CoordinatedClosureQueue() + + def func(): + logging.info('func running') + + coord = coordinator.Coordinator(clean_stop_exception_types=[]) + + def process_queue(): + with coord.stop_on_exception(): + closure_queue.get() + closure_queue.mark_finished() + + closure_queue.put( + coordinator_lib.Closure(func, closure_queue._cancellation_mgr)) + t = threading.Thread(target=process_queue) + t.start() + coord.join([t]) + + # This test asserts that waiting at the time the function has been processed + # doesn't time out. + closure_queue.wait() + + def _assert_one_unblock_the_other(self, first_fn, second_fn): + """Asserts `second_fn` wouldn't return before `first_fn` is finished.""" + first_fn_done = threading.Event() + second_fn_done = threading.Event() + coord = coordinator.Coordinator(clean_stop_exception_types=[]) + + def wrapped_first_fn(): + with coord.stop_on_exception(): + self.assertFalse(second_fn_done.is_set()) + first_fn() + first_fn_done.set() + + self.assertFalse(first_fn_done.is_set()) + t = threading.Thread(target=wrapped_first_fn) + t.start() + + second_fn() + self.assertTrue(first_fn_done.is_set()) + second_fn_done.set() + + coord.join([t]) + + def testWaitRaiseErrorAfterMarkFailure(self): + if sys.version_info >= (3, 8) and platform.system() == 'Windows': + # TODO(b/165013260): Fix this + self.skipTest('Test is currently broken on Windows with Python 3.8') + + closure_queue = coordinator_lib._CoordinatedClosureQueue() + closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) + closure = closure_queue.get() + + wait_finish_event = threading.Event() + coord = coordinator.Coordinator(clean_stop_exception_types=[]) + + # Using a thread to verify that closure_queue.wait() will not return until + # all inflight closures are finished. + + def mark_finished_fn(): + try: + raise ValueError('Some error.') + except ValueError as e: + closure_queue.mark_failed(e) + + def wait_fn(): + with self.assertRaises(ValueError): + closure_queue.wait() + + self._assert_one_unblock_the_other(mark_finished_fn, wait_fn) + + self.assertTrue(closure_queue.done()) + + def _create_closure(self, cancellation_mgr): + + @def_function.function() + def some_function(): + return 1.0 + + return coordinator_lib.Closure(some_function, cancellation_mgr) + + def _put_two_closures_and_get_one(self): + closure_queue = coordinator_lib._CoordinatedClosureQueue() + closure1 = self._create_closure(closure_queue._cancellation_mgr) + closure_queue.put(closure1) + + closure2 = self._create_closure(closure_queue._cancellation_mgr) + closure_queue.put(closure2) + + closure_got = closure_queue.get() # returns closure1 + self.assertIs(closure_got, closure1) + self.assertIsNot(closure_got, closure2) + return closure_queue, closure1, closure2 + + def testPutRaiseError(self): + if sys.version_info >= (3, 8) and platform.system() == 'Windows': + # TODO(b/165013260): Fix this + self.skipTest('Test is currently broken on Windows with Python 3.8') + + closure_queue, _, closure2 = self._put_two_closures_and_get_one() + + closure_queue.mark_failed(ValueError()) + + with self.assertRaises(ValueError): + closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) + + self.assertTrue(closure_queue.done()) + + with self.assertRaisesRegex( + errors.CancelledError, + 'The corresponding function is cancelled. Please reschedule the ' + 'function.'): + closure2.output_remote_value.fetch() + + # The error is cleared. + closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) + + def testWaitRaiseError(self): + if sys.version_info >= (3, 8) and platform.system() == 'Windows': + # TODO(b/165013260): Fix this + self.skipTest('Test is currently broken on Windows with Python 3.8') + + closure_queue, _, closure2 = self._put_two_closures_and_get_one() + + closure_queue.mark_failed(ValueError()) + + with self.assertRaises(ValueError): + closure_queue.wait() + self.assertTrue(closure_queue.done()) + + with self.assertRaisesRegex( + errors.CancelledError, + 'The corresponding function is cancelled. Please reschedule the ' + 'function.'): + closure2.output_remote_value.fetch() + + # The error is cleared. + closure_queue.wait() + + def testDoneRaiseError(self): + if sys.version_info >= (3, 8) and platform.system() == 'Windows': + # TODO(b/165013260): Fix this + self.skipTest('Test is currently broken on Windows with Python 3.8') + + closure_queue, _, _ = self._put_two_closures_and_get_one() + + self.assertFalse(closure_queue.done()) + closure_queue.mark_failed(ValueError()) + with self.assertRaises(ValueError): + closure_queue.done() + + def _set_error(self, closure_queue, closure, error): + try: + raise error + except Exception as e: # pylint: disable=broad-except + closure.output_remote_value._set_error(e) + closure_queue.mark_failed(e) + + def _test_cancel_closure_when_error(self, call_wait): + if sys.version_info >= (3, 8) and platform.system() == 'Windows': + # TODO(b/165013260): Fix this + self.skipTest('Test is currently broken on Windows with Python 3.8') + + closure_queue, closure1, closure2 = self._put_two_closures_and_get_one() + closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) + closure_queue.get() + # At this moment, there are two inflight, one in queue. + self.assertEqual(closure_queue._inflight_closure_count, 2) + + # Hold a copy of the queue's cancellation manager at this point + initial_cm = closure_queue._cancellation_mgr + + # Simulating closure1 fails. + self._set_error(closure_queue, closure1, ValueError('Some error.')) + + # At this moment, there are one inflight, one in queue. + self.assertEqual(closure_queue._queue.qsize(), 1) + self.assertEqual(closure_queue._inflight_closure_count, 1) + + closure3 = self._create_closure(closure_queue._cancellation_mgr) + + def fake_cancellation(): + self._set_error(closure_queue, closure2, + ValueError('Fake cancellation error.')) + + def report_error(): + # It should not report the fake cancellation error. + with self.assertRaisesRegex(ValueError, 'Some error.'): + # Verifying `wait()` or `put()` raises even if one closure is in + # flight. + if call_wait: + closure_queue.wait() + else: + closure_queue.put(closure3) + + self._assert_one_unblock_the_other(fake_cancellation, report_error) + + # The original cancellation manager of the queue has been cancelled. + self.assertTrue(initial_cm.is_cancelled) + + # At this moment, there is zero inflight, nothing in queue. + self.assertTrue(closure_queue._queue.empty()) + self.assertEqual(closure_queue._inflight_closure_count, 0) + self.assertIsNone(closure_queue._error) + + # This asserts that closure1 has errored. + with self.assertRaisesRegex(ValueError, 'Some error.'): + closure1.output_remote_value.fetch() + + # The following asserts that closure3 should have been cancelled. + if not call_wait: + with self.assertRaisesRegex( + errors.CancelledError, + 'The corresponding function is cancelled. Please reschedule the ' + 'function.'): + closure3.output_remote_value.fetch() + + # Closure2 was an inflight closure when it got cancelled. + self.assertEqual(closure2.output_remote_value._status, + coordinator_lib._RemoteValueStatus.READY) + with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'): + closure2.output_remote_value.fetch() + + # This asserts that the queue has a clear state. + self.testBasic() + + def testWaitRaiseErrorAfterCancelClosure(self): + self._test_cancel_closure_when_error(call_wait=True) + + def testPutRaiseErrorAfterCancelClosure(self): + self._test_cancel_closure_when_error(call_wait=False) + + def testStateIsRestoredAfterJoinIsCalled(self): + if sys.version_info >= (3, 8) and platform.system() == 'Windows': + # TODO(b/165013260): Fix this + self.skipTest('Test is currently broken on Windows with Python 3.8') + + closure_queue, _, _ = self._put_two_closures_and_get_one() + self.assertEqual(closure_queue._inflight_closure_count, 1) + closure_queue.mark_failed(ValueError('test error')) + with self.assertRaises(ValueError): + closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) + + # Its error should have been cleared. + self.assertIsNone(closure_queue._error) + closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) + self.assertIsNone(closure_queue._error) + + def testThreadSafey(self): + thread_count = 10 + queue = coordinator_lib._CoordinatedClosureQueue() + + # Each thread performs 20 queue actions: 10 are `put_back` and 10 are + # `mark_finished`. + action_count = 20 + + def func(): + for i in range(action_count): + closure = queue.get() + if i % 2 == 0: + queue.put_back(closure) + else: + queue.mark_finished() + + threads = [threading.Thread(target=func) for i in range(thread_count)] + for t in threads: + t.start() + + for _ in range(thread_count * action_count // 2): + queue.put(self._create_closure(queue._cancellation_mgr)) + queue.wait() + self.assertTrue(queue.done()) + + +class ErrorReportingThread(threading.Thread): + + error = None + + def __init__(self, *args, **kwargs): + assert 'target' in kwargs + target = kwargs['target'] + + @functools.wraps(target) + def wrapped_target(*args, **kwargs): + try: + return target(*args, **kwargs) + except Exception as e: # pylint: disable=broad-except + ErrorReportingThread.error = e + + kwargs['target'] = wrapped_target + super(ErrorReportingThread, self).__init__(*args, **kwargs) + + +class TestCaseWithErrorReportingThread(test.TestCase): + + @classmethod + def setUpClass(cls): + cls._threading_thread = threading.Thread + threading.Thread = ErrorReportingThread + super(TestCaseWithErrorReportingThread, cls).setUpClass() + + @classmethod + def tearDownClass(cls): + super(TestCaseWithErrorReportingThread, cls).tearDownClass() + threading.Thread = cls._threading_thread + + def setUp(self): + ErrorReportingThread.error = None + super(TestCaseWithErrorReportingThread, self).setUp() + + def tearDown(self): + super(TestCaseWithErrorReportingThread, self).tearDown() + if ErrorReportingThread.error: + raise ErrorReportingThread.error # pylint: disable=raising-bad-type + + +def make_coordinator(num_workers, num_ps): + # TODO(rchao): Test the internal rpc_layer version. + cluster_def = multi_worker_test_base.create_in_process_cluster( + num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc') + cluster_def['chief'] = [ + 'localhost:%d' % multi_worker_test_base.pick_unused_port() + ] + cluster_resolver = SimpleClusterResolver( + ClusterSpec(cluster_def), rpc_layer='grpc') + strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( + cluster_resolver) + return coordinator_lib.ClusterCoordinator(strategy) + + +class ClusterCoordinatorTest(TestCaseWithErrorReportingThread): + + @classmethod + def setUpClass(cls): + super(ClusterCoordinatorTest, cls).setUpClass() + cls.coordinator = make_coordinator(num_workers=3, num_ps=2) + cls.strategy = cls.coordinator.strategy + + def testFnReturnNestedValues(self): + x = constant_op.constant(1) + + @def_function.function + def f(): + return x + 1, (x + 2, x + 3), [x + 4], {'v': x} + + got = self.coordinator.schedule(f) + want = 2, (3, 4), [5], {'v': 1} + self.assertEqual(got.fetch(), want) + self.assertEqual(self.coordinator.fetch(got), want) + + def testFetchingRemoteValueStructure(self): + x = constant_op.constant(1) + + @def_function.function + def f(): + return x + 1, (x + 2, x + 3), [x + 4], {'v': x} + + want = 2, (3, 4), [5], {'v': 1} + remote_value_list = [self.coordinator.schedule(f) for _ in range(5)] + self.assertAllEqual( + self.coordinator.fetch(remote_value_list), [want for _ in range(5)]) + + def testInputFunction(self): + + def input_fn(): + return dataset_ops.DatasetV2.range(1, 2) + + with self.strategy.scope(): + v = variables.Variable(initial_value=0, dtype=dtypes.int64) + + @def_function.function + def worker_fn(iterator): + x = next(iterator) + v.assign_add(x) + return x + + distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn) + result = self.coordinator.schedule( + worker_fn, args=(iter(distributed_dataset),)) + result = self.coordinator.fetch(result) + self.assertEqual(result, (1,)) + result = self.coordinator.schedule( + worker_fn, args=(iter(distributed_dataset),)) + result = self.coordinator.fetch(result) + self.assertEqual(result, (1,)) + + self.assertAlmostEqual(v.read_value().numpy(), 2, delta=1e-6) + + def testAsyncScheduleAndJoin(self): + + def input_fn(): + return dataset_ops.DatasetV2.from_tensor_slices([2] * 10) + + with self.strategy.scope(): + v = variables.Variable(initial_value=0, dtype=dtypes.int32) + + # TODO(yuefengz): the following tf.function has a return value which is None + # in its structured_outputs. + @def_function.function + def worker_fn(iterator): + x = next(iterator) + v.assign_add(x) + + distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn) + + iterator = iter(distributed_dataset) + + # Verifying joining without any scheduling doesn't hang. + self.coordinator.join() + self.assertEqual(v.read_value().numpy(), 0) + + for _ in range(5): + self.coordinator.schedule(worker_fn, args=(iterator,)) + self.coordinator.join() + + # With 5 addition it should be 2*5 = 10. + self.assertEqual(v.read_value().numpy(), 10) + + for _ in range(5): + self.coordinator.schedule(worker_fn, args=(iterator,)) + + # Verifying multiple join is fine. + self.coordinator.join() + self.coordinator.join() + self.coordinator.join() + + self.assertTrue(self.coordinator.done()) + + # Likewise, it's now 20. + self.assertEqual(v.read_value().numpy(), 20) + + def testInputFunctionWithMap(self): + self._map_fn_tracing_count = 0 + + def input_fn(): + + def map_fn(x): + self._map_fn_tracing_count += 1 + return x + 10 + + return dataset_ops.DatasetV2.range(0, 10).map(map_fn) + + @def_function.function + def worker_fn(iterator): + return next(iterator) + + distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn) + result = self.coordinator.schedule( + worker_fn, args=(iter(distributed_dataset),)) + self.assertEqual(result.fetch(), (10,)) + self.assertEqual(self._map_fn_tracing_count, 1) + + def testInputFunctionCreateVariables(self): + + def input_fn(): + v = variables.Variable(initial_value=0.0) + return v.read_value() + + with self.assertRaises(ValueError): + self.coordinator.create_per_worker_dataset(input_fn) + + def testDatasetsShuffledDifferently(self): + # This test requires at least two workers in the cluster. + self.assertGreaterEqual(len(self.coordinator.cluster.workers), 2) + + random_seed.set_random_seed(None) + + def input_fn(): + return dataset_ops.DatasetV2.range(0, 100).shuffle(100) + + distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn) + distributed_iterator = iter(distributed_dataset) + + # Get elements from the first two iterators. + iterator_1 = distributed_iterator._values[0] + iterator_1._rebuild_on(self.coordinator.cluster.workers[0]) + iterator_1 = iterator_1.fetch() + elements_in_iterator_1 = [e.numpy() for e in iterator_1] + + iterator_2 = distributed_iterator._values[1] + iterator_2._rebuild_on(self.coordinator.cluster.workers[1]) + iterator_2 = iterator_2.fetch() + elements_in_iterator_2 = [e.numpy() for e in iterator_2] + + self.assertNotAllEqual(elements_in_iterator_1, elements_in_iterator_2) + + def testPerWorkerValue(self): + self.skipTest('b/168569314') + var_shape = tuple() + var_dtype = dtypes.float32 + var_name = 'var' + + def create_var(): + var = variables.Variable( + initial_value=0.0, dtype=var_dtype, name=var_name) + self.assertIn('worker', var.device) + return var + + worker_local_var = self.coordinator._create_per_worker_resources(create_var) + + # The following is a workaround to allow `worker_local_var` to be passed in + # as args to the `coordinator.schedule` method which requires tensor specs + # to trace tf.function but _create_worker_resources' return values don't + # have tensor specs. We can get rid of this workaround once + # _create_worker_resources is able to infer the tensor spec of the return + # value of the function passed in. See b/154675763. + for var in worker_local_var._values: + var._type_spec = tensor_spec.TensorSpec(var_shape, var_dtype, var_name) + + def worker_fn(var): + var.assign_add(1.0) + + for _ in range(10): + # Which slice of `worker_local_var` will be used will depend on which + # worker the `worker_fn` gets scheduled on. + self.coordinator.schedule(worker_fn, args=(worker_local_var,)) + self.coordinator.join() + + var_sum = sum(self.coordinator.fetch(worker_local_var._values)) + self.assertEqual(var_sum, 10.0) + + def testDisallowRemoteValueAsInput(self): + + @def_function.function + def func_0(): + return 1.0 + + @def_function.function + def func_1(x): + return x + 1.0 + + remote_v = self.coordinator.schedule(func_0) + with self.assertRaises(ValueError): + self.coordinator.schedule(func_1, args=(remote_v,)) + + +class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest): + """Test basic functionality works with explicit maximum closure queue size. + + Execute the same set of test cases as in `ClusterCoordinatorTest`, with an + explicit size limit for the closure queue. Note that even when the queue size + is set to infinite, there is still a maximum practical size (depends on host + memory limit) that might cause the queue.put operations to be blocking when + scheduling a large number of closures on a big cluster. These tests make sure + that the coordinator does not run into deadlocks in such scenario. + """ + + @classmethod + def setUpClass(cls): + super(LimitedClosureQueueSizeBasicTest, cls).setUpClass() + coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2 + cls.coordinator = make_coordinator(num_workers=3, num_ps=2) + cls.strategy = cls.coordinator.strategy + + +class ScheduleStartDelayTest(ClusterCoordinatorTest): + """Test basic functionality works with worker scheduling delay. + + This is basically to make sure that setting environment variables + `TF_COORDINATOR_SCHEDULE_START_DELAY` and + `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX` will cause any failure. + """ + + @classmethod + def setUpClass(cls): + super(ScheduleStartDelayTest, cls).setUpClass() + os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY'] = '2' + os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX'] = '4' + cls.coordinator = make_coordinator(num_workers=3, num_ps=2) + cls.strategy = cls.coordinator.strategy + + @classmethod + def tearDownClass(cls): + del os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY'] + del os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX'] + super(ScheduleStartDelayTest, cls).tearDownClass() + + +class ErrorReportingTest(TestCaseWithErrorReportingThread): + + @classmethod + def setUpClass(cls): + super(ErrorReportingTest, cls).setUpClass() + cls.coordinator = make_coordinator(num_workers=3, num_ps=2) + cls.strategy = cls.coordinator.strategy + + with cls.strategy.scope(): + cls.iteration = variables.Variable(initial_value=0.0) + + @def_function.function + def _normal_function(self): + x = random_ops.random_uniform((2, 10)) + y = random_ops.random_uniform((10, 2)) + self.iteration.assign_add(1.0) + return math_ops.reduce_mean(math_ops.matmul(x, y)) + + @def_function.function + def _error_function(self): + x = random_ops.random_uniform((2, 10)) + y = random_ops.random_uniform((10, 2)) + check_ops.assert_non_positive_v2(math_ops.reduce_sum(math_ops.matmul(x, y))) + self.iteration.assign_add(1.0) + return self.iteration + + @def_function.function + def _long_function(self): + x = random_ops.random_uniform((1000, 1000)) + for _ in math_ops.range(10000): + a = random_ops.random_uniform((1000, 1000)) + b = random_ops.random_uniform((1000, 1000)) + x += math_ops.matmul(a, b) + return x + + def testJoinRaiseError(self): + for _ in range(3): + self.coordinator.schedule(self._normal_function) + self.coordinator.schedule(self._error_function) + with self.assertRaises(errors.InvalidArgumentError): + self.coordinator.join() + + def testScheduleRaiseError(self): + for _ in range(3): + self.coordinator.schedule(self._normal_function) + self.coordinator.schedule(self._error_function) + with self.assertRaises(errors.InvalidArgumentError): + while True: + self.coordinator.schedule(self._normal_function) + + def testScheduleRaiseErrorWithMultipleFailure(self): + for _ in range(3): + self.coordinator.schedule(self._normal_function) + self.coordinator.schedule(self._error_function) + with self.assertRaises(errors.InvalidArgumentError): + while True: + self.coordinator.schedule(self._error_function) + self.coordinator.join() + + def testErrorWillbeCleared(self): + self.coordinator.schedule(self._error_function) + with self.assertRaises(errors.InvalidArgumentError): + self.coordinator.join() + + for _ in range(3): + self.coordinator.schedule(self._normal_function) + self.coordinator.schedule(self._error_function) + with self.assertRaises(errors.InvalidArgumentError): + self.coordinator.join() + + def testRemoteValueReturnError(self): + result = self.coordinator.schedule(self._error_function) + + with self.assertRaises(errors.InvalidArgumentError): + result.fetch() + + # Clear the error. + with self.assertRaises(errors.InvalidArgumentError): + self.coordinator.join() + + def testInputError(self): + + worker_local_val = self.coordinator._create_per_worker_resources( + self._error_function) + + @def_function.function + def func(x): + return x + 1 + + result = self.coordinator.schedule(func, args=(worker_local_val,)) + with self.assertRaises(coordinator_lib.InputError): + self.coordinator.join() + + with self.assertRaises(coordinator_lib.InputError): + result.fetch() + + def testCancellation(self): + for _ in range(3): + self.coordinator.schedule(self._normal_function) + long_function = self.coordinator.schedule(self._long_function) + self.coordinator.schedule(self._error_function) + + with self.assertRaises(errors.InvalidArgumentError): + self.coordinator.join() + + with self.assertRaises(errors.CancelledError): + long_function.fetch() + + for _ in range(3): + self.coordinator.schedule(self._normal_function) + self.coordinator.join() + + +class LimitedClosureQueueErrorTest(ErrorReportingTest): + """Test error reporting works with explicit maximum closure queue size. + + Execute the same set of test cases as in ErrorReportingTest, with an explicit + size limit for the closure queue. + """ + + @classmethod + def setUpClass(cls): + super(LimitedClosureQueueErrorTest, cls).setUpClass() + coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2 + cls.coordinator = make_coordinator(num_workers=3, num_ps=2) + cls.strategy = cls.coordinator.strategy + + with cls.coordinator.strategy.scope(): + cls.iteration = variables.Variable(initial_value=0.0) + + +class StrategyIntegrationTest(test.TestCase): + + @classmethod + def setUpClass(cls): + super(StrategyIntegrationTest, cls).setUpClass() + cls.coordinator = make_coordinator(num_workers=1, num_ps=1) + cls.strategy = cls.coordinator.strategy + + def testBasicVariableAssignment(self): + self.strategy.extended._variable_count = 0 + with self.strategy.scope(): + v1 = variables.Variable(initial_value=0.0) + v2 = variables.Variable(initial_value=1.0) + self.assertEqual(self.strategy.extended._variable_count, 2) + + @def_function.function + def worker_fn(): + v1.assign_add(0.1) + v2.assign_sub(0.2) + return v1.read_value() / v2.read_value() + + results = self.coordinator.schedule(worker_fn) + logging.info('Results of experimental_run_v2: %f', + self.coordinator.fetch(results)) + + self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6) + self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6) + + def testRunAndReduce(self): + self.assertFalse(distribution_strategy_context.in_cross_replica_context()) + with self.strategy.scope(): + self.assertTrue(distribution_strategy_context.in_cross_replica_context()) + v = variables.Variable(initial_value=1) + + @def_function.function + def worker_fn(input_tensor): + + def replica_fn(input_tensor): + # Within `replica_fn`, it has to be in a replica context. + self.assertFalse( + distribution_strategy_context.in_cross_replica_context()) + return input_tensor + v, input_tensor - v + + run_result = self.strategy.run(replica_fn, args=(input_tensor,)) + reduced_result = self.strategy.reduce('SUM', run_result, axis=None) + check_ops.assert_equal_v2(run_result, (4, 2)) + check_ops.assert_equal_v2(reduced_result, (4, 2)) + return reduced_result + + # Asserting scheduling in scope has the expected behavior. + result = self.coordinator.schedule( + worker_fn, args=(constant_op.constant(3),)) + self.assertIsInstance(result, coordinator_lib.RemoteValue) + self.assertEqual(result.fetch(), (4, 2)) + + # Asserting scheduling out of scope has the expected behavior. + result = self.coordinator.schedule( + worker_fn, args=(constant_op.constant(3),)) + self.assertEqual(result.fetch(), (4, 2)) + + def testDistributeDataset(self): + + def per_worker_dataset_fn(): + dataset = dataset_ops.DatasetV2.range(1, 2) + return self.strategy.experimental_distribute_dataset(dataset) + + @def_function.function + def worker_fn(iterator): + return next(iterator) + + distributed_dataset = self.coordinator.create_per_worker_dataset( + per_worker_dataset_fn) + result = self.coordinator.schedule( + worker_fn, args=(iter(distributed_dataset),)) + result = result.fetch() + self.assertEqual(result, (1,)) + + def testDistributeDatasetsFromFunction(self): + + def per_worker_dataset_fn(): + return self.strategy.distribute_datasets_from_function( + lambda _: dataset_ops.DatasetV2.range(1, 2)) + + @def_function.function + def worker_fn(iterator): + return next(iterator) + + distributed_dataset = self.coordinator.create_per_worker_dataset( + per_worker_dataset_fn) + result = self.coordinator.schedule( + worker_fn, args=(iter(distributed_dataset),)) + result = result.fetch() + self.assertEqual(result, (1,)) + + def testCallingDistributeDatasetOutside(self): + with self.assertRaises(ValueError): + dataset = dataset_ops.DatasetV2.range(1, 2) + self.strategy.experimental_distribute_dataset(dataset) + + with self.assertRaises(ValueError): + self.strategy.distribute_datasets_from_function( + lambda _: dataset_ops.DatasetV2.range(1, 2)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py new file mode 100644 index 00000000000..cc075d09c3d --- /dev/null +++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py @@ -0,0 +1,393 @@ +# Lint as: python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fault tolerance test for parameter server training in TF2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import threading +import time + +from tensorflow.python.compat import v2_compat +from tensorflow.python.distribute import multi_process_runner +from tensorflow.python.distribute import multi_worker_test_base +from tensorflow.python.distribute import parameter_server_strategy_v2 +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.coordinator import cluster_coordinator +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import coordinator as thread_coordinator +from tensorflow.python.training import server_lib + +_RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker" +_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps" + + +class Model(object): + + def __init__(self, coordinator): + self.cluster_coord = coordinator + self.strategy = self.cluster_coord.strategy + with self.cluster_coord.strategy.scope(): + self.build() + + def build(self): + self.w = variables.Variable( + initial_value=random_ops.random_uniform((1000, 1000)), + dtype=dtypes.float32) + self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32) + + @def_function.function + def train_fn(self): + # train_fn roughly took 0.5s to execute on Intel Xeon Gold 6154 (3.00GHZ) + # w/o any compilation optimization (two worker setup). + for _ in math_ops.range(5): + x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), self.w) + self.w.assign_add(x) + self.iterations.assign_add(1) + + def schedule_training_functions(self, num_steps): + with self.strategy.scope(): + for _ in range(num_steps): + self.cluster_coord.schedule(self.train_fn) + + def join_training_functions(self): + self.cluster_coord.join() + + +class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring + + NUM_WORKERS = 2 + NUM_PS = 2 + + def setUp(self): + super(FaultToleranceTest, self).setUp() + + # Set the environment variable to prevent hanging upon job failure and + # restart. Note that it defaults to 'use_caller' at Google, but defaults + # to False in OSS. + os.environ["GRPC_FAIL_FAST"] = "use_caller" + + self._cluster = multi_worker_test_base.create_multi_process_cluster( + num_workers=FaultToleranceTest.NUM_WORKERS, + num_ps=FaultToleranceTest.NUM_PS, + rpc_layer="grpc") + self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict() + self._cluster_def["chief"] = [ + "localhost:%d" % multi_worker_test_base.pick_unused_port() + ] + cluster_resolver = SimpleClusterResolver( + server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc") + + # The strategy's constructor would connect to the cluster. + self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( + cluster_resolver) + self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy) + + self.thread_coord = thread_coordinator.Coordinator( + clean_stop_exception_types=[]) + + def tearDown(self): + super(FaultToleranceTest, self).tearDown() + self._cluster.stop() + + def _restart(self, downtime_secs, job): + """Kills `job` (index: 0) and restarts it after `downtime_secs`. + + Args: + downtime_secs: secs before restarting the job. + job: a string specifying the job to restart. + """ + self._cluster.kill_task(job, 0) + time.sleep(downtime_secs) + self.assertFalse(context.check_alive("/job:%s/replica:0/task:0" % job)) + self._cluster.start_task(job, 0) + while not context.check_alive("/job:%s/replica:0/task:0" % job): + time.sleep(1) + + def _restart_in_thread(self, downtime_secs, restart_job): + + def _restart_fn(): + with self.thread_coord.stop_on_exception(): + self._restart(downtime_secs, restart_job) + + restart_thread = threading.Thread(target=_restart_fn) + restart_thread.start() + return restart_thread + + def testOneWorkerPreemption(self): + # A blackbox test to make sure the model can still train when there is + # worker preemption. + model = Model(self.cluster_coord) + model.schedule_training_functions(10) + + time.sleep(1) # Let it run a couple steps. + self.assertFalse( + self.cluster_coord.done(), "cluster finishes work before restart, this" + " is most likely due to the test runs in more powerful machine" + " compared to the one it previously runs. This setup is brittle but" + " there are no easy better alternatives. To fix the failure, consider" + " adding more work to the cluster, e.g, scheduling more functions.") + self._restart(5, "worker") + + model.join_training_functions() + self.assertGreaterEqual(model.iterations.numpy(), 10) + + def testOneWorkerPreemptionWithCancellation(self): + + @def_function.function + def normal_function(): + x = random_ops.random_uniform((2, 10)) + y = random_ops.random_uniform((10, 2)) + return math_ops.reduce_mean(math_ops.matmul(x, y)) + + @def_function.function + def error_function(): + x = random_ops.random_uniform((2, 10)) + y = random_ops.random_uniform((10, 2)) + check_ops.assert_non_positive_v2( + math_ops.reduce_sum(math_ops.matmul(x, y))) + return x + + @def_function.function + def long_function(): + x = random_ops.random_uniform((1000, 1000)) + for _ in math_ops.range(10000): + a = random_ops.random_uniform((1000, 1000)) + b = random_ops.random_uniform((1000, 1000)) + x += math_ops.matmul(a, b) + return x + + for _ in range(3): + self.cluster_coord.schedule(normal_function) + long_function_result = self.cluster_coord.schedule(long_function) + self.cluster_coord.schedule(error_function) + + time.sleep(1) # Let it run a couple steps. + self._restart(1, "worker") + + with self.assertRaises(errors.InvalidArgumentError): + self.cluster_coord.join() + + with self.assertRaises(errors.CancelledError): + long_function_result.fetch() + + for _ in range(3): + self.cluster_coord.schedule(normal_function) + self.cluster_coord.join() + + def testHandleDatasetCreationFailure(self): + model = Model(self.cluster_coord) + + restart_thread = self._restart_in_thread(5, "worker") + + model.schedule_training_functions(3) + model.join_training_functions() + + self.thread_coord.join([restart_thread]) + self.assertGreaterEqual(model.iterations.numpy(), 3) + + def testWorkerPreemptionErrorType(self): + + @def_function.function + def worker_train_fn(): + x = random_ops.random_uniform((2, 10)) + y = random_ops.random_uniform((10, 2)) + return math_ops.reduce_mean(math_ops.matmul(x, y)) + + def run_fn(): + with self.thread_coord.stop_on_exception(): + with ops.device("/job:worker/replica:0/task:0"): + for _ in range(3): + for _ in range(3): + worker_train_fn() + time.sleep(5) + + run_thread = threading.Thread(target=run_fn) + run_thread.start() + time.sleep(1) # Let it run a couple steps. + self._restart(2, "worker") + + try: + self.thread_coord.join([run_thread]) + except errors.UnavailableError as e: + logging.info("Got exception %r, error message is %s", e, e) + + self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except + self.assertNotIn(_RPC_ERROR_FROM_PS, str(e)) + + self.assertTrue("failed to connect to all addresses" in str(e) or + "Unable to find a context_id" in str(e) or + "Socket closed" in str(e) or + "Connection reset by peer" in str(e) or + "Transport closed" in str(e)) + + def testWorkerPreemptionErrorTypeWithPythonFunction(self): + + def worker_train_fn(): + x = random_ops.random_uniform((2, 10)) + y = random_ops.random_uniform((10, 2)) + return math_ops.reduce_mean(math_ops.matmul(x, y)) + + def run_fn(): + with self.thread_coord.stop_on_exception(): + with ops.device("/job:worker/replica:0/task:0"): + for _ in range(3): + for _ in range(3): + worker_train_fn() + time.sleep(5) + + run_thread = threading.Thread(target=run_fn) + run_thread.start() + time.sleep(1) # Let it run a couple steps. + self._restart(2, "worker") + + try: + self.thread_coord.join([run_thread]) + except errors.UnavailableError as e: + logging.info("Got exception %r, error message is %s", e, e) + + self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except + self.assertNotIn(_RPC_ERROR_FROM_PS, str(e)) + + self.assertTrue("failed to connect to all addresses" in str(e) or + "Unable to find a context_id" in str(e) or + "Socket closed" in str(e) or + "Connection reset by peer" in str(e) or + "Transport closed" in str(e)) + + def testPSPreemptionErrorType(self): + + with ops.device("/job:ps/replica:0/task:0"): + v = variables.Variable( + initial_value=random_ops.random_uniform((2, 10)), + dtype=dtypes.float32) + + @def_function.function + def worker_train_fn(): + y = random_ops.random_uniform((10, 2)) + return math_ops.reduce_mean(math_ops.matmul(v, y)) + + def run_fn(): + with self.thread_coord.stop_on_exception(): + with ops.device("/job:worker/replica:0/task:0"): + for _ in range(3): + for _ in range(3): + worker_train_fn() + time.sleep(5) + + run_thread = threading.Thread(target=run_fn) + run_thread.start() + time.sleep(1) # Let it run a couple steps. + + # Use a short restart delay to cover the case that RPC channel is reused + self._restart(1, "ps") + + try: + self.thread_coord.join([run_thread]) + except (errors.UnavailableError, errors.AbortedError) as e: + logging.info("Got exception %r, error message is %s", e, e) + self.assertIn(_RPC_ERROR_FROM_PS, str(e)) # pylint: disable=g-assert-in-except + + if isinstance(e, errors.UnavailableError): + self.assertTrue("failed to connect to all addresses" in str(e) or + "Unable to find a context_id" in str(e) or + "Socket closed" in str(e) or + "Connection reset by peer" in str(e) or + "Transport closed" in str(e)) + + if isinstance(e, errors.AbortedError): + self.assertIn("RecvTensor expects a different device incarnation", + str(e)) + + def testTwoWorkersPreempted(self): + model = Model(self.cluster_coord) + model.schedule_training_functions(10) + + time.sleep(1) + self.assertFalse(self.cluster_coord.done()) + self._cluster.kill_task("worker", 0) + self._cluster.kill_task("worker", 1) + time.sleep(2) + self.assertFalse(context.check_alive("/job:worker/replica:0/task:0")) + self.assertFalse(context.check_alive("/job:worker/replica:0/task:1")) + self._cluster.start_task("worker", 0) + self._cluster.start_task("worker", 1) + time.sleep(2) + self.assertTrue(context.check_alive("/job:worker/replica:0/task:0")) + self.assertTrue(context.check_alive("/job:worker/replica:0/task:1")) + + model.join_training_functions() + self.assertGreaterEqual(model.iterations.numpy(), 10) + + def testWorkerContinuousFailure(self): + model = Model(self.cluster_coord) + model.schedule_training_functions(10) + + time.sleep(1) + self.assertFalse(self.cluster_coord.done()) + self._cluster.kill_task("worker", 0) + time.sleep(2) + self.assertFalse(context.check_alive("/job:worker/replica:0/task:0")) + self._cluster.start_task("worker", 0) + time.sleep(2) + self.assertTrue(context.check_alive("/job:worker/replica:0/task:0")) + self._cluster.kill_task("worker", 0) + time.sleep(2) + self.assertFalse(context.check_alive("/job:worker/replica:0/task:0")) + self._cluster.start_task("worker", 0) + time.sleep(2) + self.assertTrue(context.check_alive("/job:worker/replica:0/task:0")) + + model.join_training_functions() + self.assertGreaterEqual(model.iterations.numpy(), 10) + + def testClusterStateNotDisrupted(self): + # This test has side effects and can disrupt other tests, even if the + # resource created by it will not be used in following tests. + # TODO(b/155209534): enable this test. + # self.testPSPreemptionErrorType() + + self.thread_coord = thread_coordinator.Coordinator( + clean_stop_exception_types=[]) + self.testOneWorkerPreemption() + + self.thread_coord = thread_coordinator.Coordinator( + clean_stop_exception_types=[]) + self.testWorkerPreemptionErrorType() + + # In previous tests, workers may fail after training is done. But the + # following tests start with creating resources where failure is not + # handled. + # TODO(b/153888707): enable the following two tests. + # self.testTwoWorkersPreempted() + # self.testWorkerContinuousFailure() + + +if __name__ == "__main__": + v2_compat.enable_v2_behavior() + multi_process_runner.test_main() diff --git a/tensorflow/python/distribute/client/metric_utils.py b/tensorflow/python/distribute/coordinator/metric_utils.py similarity index 91% rename from tensorflow/python/distribute/client/metric_utils.py rename to tensorflow/python/distribute/coordinator/metric_utils.py index f0a6628a333..308da213904 100644 --- a/tensorflow/python/distribute/client/metric_utils.py +++ b/tensorflow/python/distribute/coordinator/metric_utils.py @@ -31,15 +31,15 @@ enable_metrics = False _time_buckets = monitoring.ExponentialBuckets(0.001, 10, 6) _function_tracing_sampler = monitoring.Sampler( - '/tensorflow/api/ps_strategy/client/function_tracing', _time_buckets, + '/tensorflow/api/ps_strategy/coordinator/function_tracing', _time_buckets, 'Sampler to track the time (in seconds) for tracing functions.') _closure_execution_sampler = monitoring.Sampler( - '/tensorflow/api/ps_strategy/client/closure_execution', _time_buckets, + '/tensorflow/api/ps_strategy/coordinator/closure_execution', _time_buckets, 'Sampler to track the time (in seconds) for executing closures.') _remote_value_fetch_sampler = monitoring.Sampler( - '/tensorflow/api/ps_strategy/client/remote_value_fetch', _time_buckets, + '/tensorflow/api/ps_strategy/coordinator/remote_value_fetch', _time_buckets, 'Sampler to track the time (in seconds) for fetching remote_value.') _METRICS_MAPPING = { diff --git a/tensorflow/python/distribute/client/metric_utils_test.py b/tensorflow/python/distribute/coordinator/metric_utils_test.py similarity index 90% rename from tensorflow/python/distribute/client/metric_utils_test.py rename to tensorflow/python/distribute/coordinator/metric_utils_test.py index f94cdcb6d76..abd4221df4d 100644 --- a/tensorflow/python/distribute/client/metric_utils_test.py +++ b/tensorflow/python/distribute/coordinator/metric_utils_test.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for metrics collecting in client.""" +"""Tests for metrics collecting in coordinator.""" from __future__ import absolute_import from __future__ import division @@ -22,9 +22,9 @@ from __future__ import print_function import time from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import parameter_server_strategy_v2 -from tensorflow.python.distribute.client import client -from tensorflow.python.distribute.client import metric_utils from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib +from tensorflow.python.distribute.coordinator import metric_utils from tensorflow.python.eager import def_function from tensorflow.python.eager import test from tensorflow.python.training.server_lib import ClusterSpec @@ -35,7 +35,7 @@ class MetricUtilsTest(test.TestCase): def get_rpc_layer(self): return 'grpc' - def testClientMetrics(self): + def testClusterCoordinatorMetrics(self): metric_utils.enable_metrics = True @@ -48,7 +48,7 @@ class MetricUtilsTest(test.TestCase): ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer()) strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( cluster_resolver) - cluster = client.Cluster(strategy) + cluster = coordinator_lib.Cluster(strategy) @def_function.function def func(): diff --git a/tensorflow/python/distribute/client/utils.py b/tensorflow/python/distribute/coordinator/utils.py similarity index 100% rename from tensorflow/python/distribute/client/utils.py rename to tensorflow/python/distribute/coordinator/utils.py diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index 6d2a4e16f84..c5aca728827 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -20,7 +20,6 @@ from __future__ import print_function import collections import copy -import enum import threading import six @@ -34,6 +33,7 @@ from tensorflow.python.distribute import ps_values from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import tpu_values from tensorflow.python.distribute import values as value_lib +from tensorflow.python.distribute import values_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import executor as executor_lib @@ -250,11 +250,7 @@ class CrossDeviceOps(object): # Returns 1 by default, the value may be overridden by sub classes. return 1 - def reduce(self, - reduce_op, - per_replica_value, - destinations, - experimental_hints=None): + def reduce(self, reduce_op, per_replica_value, destinations, options=None): """Reduce `per_replica_value` to `destinations`. See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in @@ -271,8 +267,8 @@ class CrossDeviceOps(object): `destinations`. Note that if it's a `tf.Variable`, the value is reduced to the devices of that variable, and this method doesn't update the variable. - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See - `tf.distribute.experimental.CollectiveHints` for details. + options: a `tf.distribute.experimental.CommunicationOptions`. See + `tf.distribute.experimental.CommunicationOptions` for details. Returns: A `tf.Tensor` or `tf.distribute.DistributedValues`. @@ -282,6 +278,8 @@ class CrossDeviceOps(object): `tf.distribute.DistributedValues` or if destinations is not a string, `tf.Variable` or `tf.distribute.DistributedValues`. """ + if options is None: + options = collective_util.Options() if not isinstance(per_replica_value, value_lib.DistributedValues): per_replica_value = _make_tensor_into_per_replica(per_replica_value) @@ -295,16 +293,12 @@ class CrossDeviceOps(object): v = array_ops.identity(per_replica_value.values[0]) return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored) - if experimental_hints is None: - experimental_hints = collective_util.Hints() + if options is None: + options = collective_util.Options() return self.reduce_implementation(reduce_op, per_replica_value, - destinations, experimental_hints) + destinations, options) - def _gather(self, - per_replica_value, - destinations, - axis, - experimental_hints=None): + def _gather(self, per_replica_value, destinations, axis, options=None): """Gather `per_replica_value` to `destinations`. Args: @@ -318,8 +312,8 @@ class CrossDeviceOps(object): variable. axis: specifies the dimension to gather along within each replica's tensor. - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See - `tf.distribute.experimental.CollectiveHints` for details. + options: a `tf.distribute.experimental.CommunicationOptions`. See + `tf.distribute.experimental.CommunicationOptions` for details. Returns: A `tf.Tensor` or `tf.distribute.DistributedValues` @@ -329,8 +323,11 @@ class CrossDeviceOps(object): `tf.distribute.DistributedValues` or if destinations is not a string, `tf.Variable` or `tf.distribute.DistributedValues`. """ - if experimental_hints is None: - experimental_hints = collective_util.Hints() + if isinstance(per_replica_value, ops.IndexedSlices): + raise NotImplementedError("gather/all_gather does not support " + "IndexedSlices") + if options is None: + options = collective_util.Options() if not isinstance(per_replica_value, value_lib.DistributedValues): per_replica_value = _make_tensor_into_per_replica(per_replica_value) @@ -346,10 +343,10 @@ class CrossDeviceOps(object): return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored) return self._gather_implementation(per_replica_value, destinations, axis, - experimental_hints) + options) def _gather_implementation(self, per_replica_value, destinations, axis, - experimental_hints): + options): """Implementation of `gather` method of `tf.distribute.CrossDeviceOps`. Overriding this method is useful for subclass implementers. @@ -365,8 +362,8 @@ class CrossDeviceOps(object): variable. axis: specifies the dimension to gather along within each replica's tensor. - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See - `tf.distribute.experimental.CollectiveHints` for details. + options: a `tf.distribute.experimental.CommunicationOptions`. See + `tf.distribute.experimental.CommunicationOptions` for details. Returns: A `tf.Tensor` or `tf.distribute.DistributedValues`. @@ -379,10 +376,7 @@ class CrossDeviceOps(object): raise NotImplementedError( "_gather method must be implemented in descendants.") - def batch_reduce(self, - reduce_op, - value_destination_pairs, - experimental_hints=None): + def batch_reduce(self, reduce_op, value_destination_pairs, options=None): """Reduce values to destinations in batches. See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be @@ -393,8 +387,8 @@ class CrossDeviceOps(object): combined. value_destination_pairs: a sequence of (value, destinations) pairs. See `tf.distribute.CrossDeviceOps.reduce` for descriptions. - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See - `tf.distribute.experimental.CollectiveHints` for details. + options: a `tf.distribute.experimental.CommunicationOptions`. See + `tf.distribute.experimental.CommunicationOptions` for details. Returns: A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair @@ -404,6 +398,8 @@ class CrossDeviceOps(object): ValueError: if `value_destination_pairs` is not an iterable of tuples of `tf.distribute.DistributedValues` and destinations. """ + if options is None: + options = collective_util.Options() # TODO(yuefengz): if destinations are different, split into several # `_batch_reduce` invocations. if not _validate_value_destination_pairs(value_destination_pairs): @@ -424,10 +420,10 @@ class CrossDeviceOps(object): for v, _ in value_destination_pairs ] - if experimental_hints is None: - experimental_hints = collective_util.Hints() + if options is None: + options = collective_util.Options() return self.batch_reduce_implementation(reduce_op, value_destination_pairs, - experimental_hints) + options) def broadcast(self, tensor, destinations): """Broadcast `tensor` to `destinations`. @@ -450,7 +446,7 @@ class CrossDeviceOps(object): @doc_controls.for_subclass_implementers def reduce_implementation(self, reduce_op, per_replica_value, destinations, - experimental_hints): + options): """Implementation of `reduce`. Overriding this method is useful for subclass implementers. @@ -466,8 +462,8 @@ class CrossDeviceOps(object): `destinations`. Note that if it's a `tf.Variable`, the value is reduced to the devices of that variable, this method doesn't update the variable. - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See - `tf.distribute.experimental.CollectiveHints` for details. + options: a `tf.distribute.experimental.CommunicationOptions`. See + `tf.distribute.experimental.CommunicationOptions` for details. Returns: A `tf.Tensor` or `tf.distribute.DistributedValues`. @@ -482,7 +478,7 @@ class CrossDeviceOps(object): @doc_controls.for_subclass_implementers def batch_reduce_implementation(self, reduce_op, value_destination_pairs, - experimental_hints): + options): """Implementation of `batch_reduce`. Overriding this method is useful for subclass implementers. @@ -492,8 +488,8 @@ class CrossDeviceOps(object): combined. value_destination_pairs: a sequence of (value, destinations) pairs. See `reduce` for descriptions. - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints - to perform collective operations. + options: a `tf.distribute.experimental.CommunicationOptions`. See + `tf.distribute.experimental.CommunicationOptions` for details. Returns: A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair @@ -557,8 +553,8 @@ class ReductionToOneDevice(CrossDeviceOps): super(ReductionToOneDevice, self).__init__() def reduce_implementation(self, reduce_op, per_replica_value, destinations, - experimental_hints): - del experimental_hints # Unused. + options): + del options # Unused. if check_destinations(destinations): devices = get_devices_from(destinations) else: @@ -572,8 +568,8 @@ class ReductionToOneDevice(CrossDeviceOps): return self.broadcast(reduced, destinations) def _gather_implementation(self, per_replica_value, destinations, axis, - experimental_hints): - del experimental_hints # Unused. + options): + del options # Unused. if check_destinations(destinations): devices = get_devices_from(destinations) else: @@ -586,10 +582,10 @@ class ReductionToOneDevice(CrossDeviceOps): return self.broadcast(gathered, destinations) def batch_reduce_implementation(self, reduce_op, value_destination_pairs, - experimental_hints): + options): return [ self.reduce_implementation( - reduce_op, t, destinations=v, experimental_hints=experimental_hints) + reduce_op, t, destinations=v, options=options) for t, v in value_destination_pairs ] @@ -805,22 +801,25 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): super(AllReduceCrossDeviceOps, self).__init__() def reduce_implementation(self, reduce_op, per_replica_value, destinations, - experimental_hints): - del experimental_hints # Unused. - if _devices_match(per_replica_value, destinations): + options): + del options # Unused. + # To use NCCL or all-reduce, source and destination devices should match, + # and none of the devices should be CPU. + if (_devices_match(per_replica_value, destinations) and + not any("cpu" in d.lower() for d in get_devices_from(destinations))): return self._batch_all_reduce(reduce_op, [per_replica_value])[0] else: return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value, destinations) def batch_reduce_implementation(self, reduce_op, value_destination_pairs, - experimental_hints): + options): if _all_devices_match(value_destination_pairs): return self._batch_all_reduce(reduce_op, [v[0] for v in value_destination_pairs]) else: return [ - self.reduce_implementation(reduce_op, value, dest, experimental_hints) + self.reduce_implementation(reduce_op, value, dest, options) for value, dest in value_destination_pairs ] @@ -881,13 +880,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): reduce_op, zip(sparse_values, sparse_values)) def _gather_implementation(self, per_replica_value, destinations, axis, - experimental_hints): + options): logging.warning("gather/all_gather with NCCL or HierarchicalCopy is not " "supported. Falling back to gather on one device and " "then broadcast. We're working on a more efficient " "implementation.") return ReductionToOneDevice()._gather(per_replica_value, destinations, axis, # pylint: disable=protected-access - experimental_hints) + options) # For compatibility with code using the old name of `AllReduceCrossDeviceOps`. @@ -978,20 +977,9 @@ class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps): num_packs=num_packs) -@tf_export("distribute.experimental.CollectiveCommunication") -class CollectiveCommunication(enum.Enum): - """Communication choices for CollectiveOps. - - * `AUTO`: Default to runtime's automatic choices. - * `RING`: TensorFlow's ring algorithms for all-reduce and - all-gather. - * `NCCL`: Use ncclAllReduce for all-reduce, and ring algorithms for - all-gather. - """ - AUTO = "AUTO" - RING = "RING" - NCCL = "NCCL" - # TODO(ayushd): add ncclAllGather implementation. +# TODO(crccw): remove after migrating all callers. +CollectiveCommunication = collective_util.CommunicationImplementation +CommunicationImplementation = collective_util.CommunicationImplementation # TODO(yuefengz): support in-graph collective all-reduce. @@ -1002,11 +990,7 @@ class CollectiveAllReduce(CrossDeviceOps): all workers and then put results on the right destinations. """ - def __init__(self, - devices, - group_size, - collective_keys=None, - communication=CollectiveCommunication.AUTO): + def __init__(self, devices, group_size, collective_keys=None): """Initializes the object. Args: @@ -1014,7 +998,6 @@ class CollectiveAllReduce(CrossDeviceOps): group_size: the global group size. For between-graph replicated training it's the total number of devices across all workers. collective_keys: an optional CollectiveKey object. - communication: indicates which collective communication to use. """ if group_size % len(devices) > 0: raise ValueError("group_size must be divisible by the number of devices.") @@ -1022,7 +1005,6 @@ class CollectiveAllReduce(CrossDeviceOps): self._group_size = group_size self._collective_keys = (collective_keys or cross_device_utils.CollectiveKeys()) - self._communication = communication # This lock guards all collective launches, i.e. calls to # cross_device_utils.build_collectve_*. # @@ -1062,9 +1044,10 @@ class CollectiveAllReduce(CrossDeviceOps): return self._group_size / len(self._devices) def reduce_implementation(self, reduce_op, per_replica_value, destinations, - experimental_hints): + options): + values_util.mark_as_unsaveable() all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value], - experimental_hints)[0] + options)[0] devices = get_devices_from(destinations) if _devices_match(per_replica_value, destinations): @@ -1093,12 +1076,13 @@ class CollectiveAllReduce(CrossDeviceOps): return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored) def batch_reduce_implementation(self, reduce_op, value_destination_pairs, - experimental_hints): + options): + values_util.mark_as_unsaveable() all_devices_match = _all_devices_match(value_destination_pairs) if all_devices_match: return self._batch_all_reduce(reduce_op, [v[0] for v in value_destination_pairs], - experimental_hints) + options) else: if not all_devices_match: logging.log_first_n( @@ -1106,43 +1090,41 @@ class CollectiveAllReduce(CrossDeviceOps): "destinations are different.", 10) return [ - self.reduce_implementation(reduce_op, value, dest, experimental_hints) + self.reduce_implementation(reduce_op, value, dest, options) for value, dest in value_destination_pairs ] - def _batch_all_reduce(self, reduce_op, per_replica_values, - experimental_hints): + def _batch_all_reduce(self, reduce_op, per_replica_values, options): """All reduce algorithm in a batch.""" dense_values, dense_indices, sparse_values, sparse_indices = ( cross_device_utils.split_by_sparsity(per_replica_values)) if dense_values: dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values, - experimental_hints) + options) else: dense_results = [] if sparse_values: sparse_results = self._do_batch_all_reduce_sparse(reduce_op, - sparse_values, - experimental_hints) + sparse_values, options) else: sparse_results = [] return cross_device_utils.stitch_values( ((dense_results, dense_indices), (sparse_results, sparse_indices))) - def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values, - experimental_hints): + def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values, options): """All-reduce across all workers in a batch.""" batch_size = len(per_replica_values) - # Pass self._communication to the runtime as a communication hint. - communication = self._communication.value + implementation = options.implementation.value # For now, we use NCCL only when batch_size > 1 since we don't have a way to # order NCCL launches. We're hoping that there's only one batched # all-reduce, which is the gradients. # TODO(b/132575814): switch to NCCL for all collectives when communication # is NCCL if and only if we can order collectives deterministically. - if self._communication == CollectiveCommunication.NCCL and batch_size == 1: - communication = CollectiveCommunication.AUTO.value + # is NCCL. + if (options.implementation == CommunicationImplementation.NCCL and + batch_size == 1): + implementation = CommunicationImplementation.AUTO.value # Reverse the lists so that there's better chance that values follows # the order in which they are calculated (e.g. when they're gradients), so @@ -1163,15 +1145,15 @@ class CollectiveAllReduce(CrossDeviceOps): with self._lock: for i in range(len(self._devices)): packs = cross_device_utils.group_by_size( - values_by_device[i], experimental_hints.bytes_per_pack) + values_by_device[i], options.bytes_per_pack) if not context.executing_eagerly() and i == 0: logging.info( "Collective batch_all_reduce: %d all-reduces, num_devices = %d, " - "group_size = %d, communication_hint = %s, num_packs = %d", - batch_size, len(self._launchers), self._group_size, communication, - len(packs)) + "group_size = %d, implementation = %s, num_packs = %d", + batch_size, len(self._launchers), self._group_size, + implementation, len(packs)) outputs_by_device.append(self._launchers[i].batch_all_reduce( - packs, communication, experimental_hints.timeout_seconds)) + packs, implementation, options.timeout_seconds)) for e in self._executors: e.wait() @@ -1188,8 +1170,7 @@ class CollectiveAllReduce(CrossDeviceOps): # Reverse the order of reduced value to recover the order in the input. return list(reversed(mirrored)) - def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values, - experimental_hints): + def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values, options): """All-reduce IndexedSlices across all workers in a batch.""" logging.log_first_n( @@ -1197,8 +1178,13 @@ class CollectiveAllReduce(CrossDeviceOps): "%d all-reduces, group_size = %d" % (len(per_replica_values), self._group_size), 10) - # Pass self._communication to the runtime as a communication hint. - communication_hint = self._communication.value + implementation = options.implementation.value + # For now, we use NCCL only when batch_size > 1. + # TODO(b/132575814): switch to NCCL for all collectives when implementation + # is NCCL. + if options.implementation == CommunicationImplementation.NCCL and len( + per_replica_values) == 1: + implementation = CommunicationImplementation.AUTO.value gathered_values = [] with self._lock: @@ -1206,8 +1192,7 @@ class CollectiveAllReduce(CrossDeviceOps): outputs = [] for i in range(len(self._devices)): outputs.append(self._launchers[i].all_reduce_indexed_slices( - per_replica.values[i], communication_hint, - experimental_hints.timeout_seconds)) + per_replica.values[i], implementation, options.timeout_seconds)) gathered_values.append(outputs) mirrored = [] @@ -1222,9 +1207,9 @@ class CollectiveAllReduce(CrossDeviceOps): return mirrored def _gather_implementation(self, per_replica_value, destinations, axis, - experimental_hints): - all_gathered = self._batch_all_gather([per_replica_value], axis, - experimental_hints)[0] + options): + all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0] + values_util.mark_as_unsaveable() devices = get_devices_from(destinations) if _devices_match(per_replica_value, destinations): @@ -1250,16 +1235,23 @@ class CollectiveAllReduce(CrossDeviceOps): index.append(array_ops.identity(all_gathered._primary)) # pylint: disable=protected-access return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored) - def _batch_all_gather(self, per_replica_values, axis, experimental_hints): + def _batch_all_gather(self, per_replica_values, axis, options): """all gather multiple per-replica-values.""" batch_size = len(per_replica_values) - # Pass self._communication to the runtime as a communication hint. - communication = self._communication.value + # Pass options.implementation to the runtime as a communication + # implementation hint. + implementation = options.implementation.value + # For now, we use NCCL only when batch_size > 1. + # TODO(b/132575814): switch to NCCL for all collectives when implementation + # is NCCL. + if (options.implementation == CommunicationImplementation.NCCL and + batch_size == 1): + implementation = CommunicationImplementation.AUTO.value logging.log_first_n( logging.INFO, "Collective batch_all_gather: %d all-gathers, " - "num_devices = %d, group_size = %d, communication_hint = %s, " % - (batch_size, len(self._devices), self._group_size, communication), 10) + "num_devices = %d, group_size = %d, implementation = %s, " % + (batch_size, len(self._devices), self._group_size, implementation), 10) def compute_gathered_values(): gathered_values = [] @@ -1268,8 +1260,8 @@ class CollectiveAllReduce(CrossDeviceOps): outputs = [] for i in range(len(self._devices)): outputs.append(self._launchers[i].all_gather( - per_replica.values[i], axis, communication, - experimental_hints.timeout_seconds)) + per_replica.values[i], axis, implementation, + options.timeout_seconds)) gathered_values.append(outputs) return gathered_values @@ -1288,8 +1280,7 @@ class CollectiveAllReduce(CrossDeviceOps): # distribute_coordinator deep-copies the strategy object, so # CollectiveAllReduce needs to support deep copy as well. collective_keys = copy.deepcopy(self._collective_keys, memo) - return CollectiveAllReduce(self._devices, self._group_size, collective_keys, - self._communication) + return CollectiveAllReduce(self._devices, self._group_size, collective_keys) def select_cross_device_ops(devices, session_config=None): diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index 461b32d57b0..191394f69af 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.python.distribute import cluster_resolver as cluster_resolver_li from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import combinations from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import reduce_util @@ -45,9 +46,11 @@ from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.util import nest -CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication +CommunicationImplementation = collective_util.CommunicationImplementation ReduceOp = reduce_util.ReduceOp IndexedSlicesValue = indexed_slices.IndexedSlicesValue IndexedSlices = indexed_slices.IndexedSlices @@ -139,14 +142,13 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): global_mpr_1p.runner.run(enable_collective_ops) global_mpr_2p.runner.run(enable_collective_ops) - def make_collective(self, num_processes, gpu_per_process, communication): + def make_collective(self, num_processes, gpu_per_process): """Returns collectives and other info to be used in tests. Args: num_processes: an integer indicating the number of processes that participate in the collective. gpu_per_process: number of GPUs (0 if no GPUs) used by each process. - communication: one of `CollectiveCommunication`. Returns: A tuple of (collective, devices, group_size) where collective is a instance @@ -166,7 +168,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): ] group_size = num_processes * len(devices) collective = cross_device_ops_lib.CollectiveAllReduce( - devices=devices, group_size=group_size, communication=communication) + devices=devices, group_size=group_size) return collective, devices, cluster_resolver.task_id def as_list(self, value): @@ -201,10 +203,13 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): "num_processes", "gpus_per_process", "reduce_op", - "communication", + "communication_options", + "use_scoped_allocator", + "use_collective_v2", ]) - RunOptions.__new__.__defaults__ = (["eager", "func_graph"], 2, 0, - ReduceOp.SUM, CollectiveCommunication.AUTO) + RunOptions.__new__.__defaults__ = (["eager", + "func_graph"], 2, 0, ReduceOp.SUM, + collective_util.Options(), True, False) def reduce_and_verify(self, inputs, expect, options): """Reduce the given `inputs` and verify the output matches `expect`. @@ -218,15 +223,17 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): """ def replica_fn(): + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = ( + options.use_collective_v2) collective, devices, pid = self.make_collective(options.num_processes, - options.gpus_per_process, - options.communication) + options.gpus_per_process) def reduce_fn(): value_fn = lambda device_idx: inputs[pid * len(devices) + device_idx] per_replica_value = make_per_replica_value(value_fn, devices) reduced_values = collective.reduce(options.reduce_op, per_replica_value, - per_replica_value) + per_replica_value, + options.communication_options) reduced_values = self.as_list(reduced_values) self.assertAllEqual(devices, [v.device for v in reduced_values]) return [ops.convert_to_tensor(v) for v in reduced_values] @@ -255,9 +262,12 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): """ def replica_fn(): + cross_device_utils.CollectiveReplicaLauncher._use_scoped_allocator = ( + options.use_scoped_allocator) + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = ( + options.use_collective_v2) collective, devices, pid = self.make_collective(options.num_processes, - options.gpus_per_process, - options.communication) + options.gpus_per_process) def batch_reduce_fn(): batch_size = len(inputs[0]) @@ -270,7 +280,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): per_replica_value = make_per_replica_value(value_fn, devices) value_dst_pairs.append((per_replica_value, per_replica_value)) reduced_values = collective.batch_reduce(options.reduce_op, - value_dst_pairs) + value_dst_pairs, + options.communication_options) reduced_values = [self.as_list(v) for v in reduced_values] for v in reduced_values: self.assertAllEqual(devices, [t.device for t in v]) @@ -293,20 +304,23 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): combinations.combine( num_processes=[1, 2], required_gpus=[0, 1, 2], - communication=[ + implementation=[ # NCCL is only used for batch reduce, so we are not including # NCCL combination here. - CollectiveCommunication.AUTO, - CollectiveCommunication.RING + CommunicationImplementation.AUTO, + CommunicationImplementation.RING ], - reduce_op=[ReduceOp.SUM, ReduceOp.MEAN])) - def testAllReduceDense(self, num_processes, required_gpus, communication, - reduce_op): + reduce_op=[ReduceOp.SUM, ReduceOp.MEAN], + use_collective_v2=[True, False])) + def testAllReduceDense(self, num_processes, required_gpus, implementation, + reduce_op, use_collective_v2): options = self.RunOptions( num_processes=num_processes, gpus_per_process=required_gpus, reduce_op=reduce_op, - communication=communication) + communication_options=collective_util.Options( + implementation=implementation), + use_collective_v2=use_collective_v2) group_size = options.num_processes * (options.gpus_per_process or 1) inputs_data = [1.0, 2.0, 3.0, 4.0] @@ -325,22 +339,25 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): combinations.combine( num_processes=[1, 2], required_gpus=[0, 1, 2], - communication=[ + implementation=[ # NCCL is only used for batch reduce, so we are not including # NCCL combination here. - CollectiveCommunication.AUTO, - CollectiveCommunication.RING + CommunicationImplementation.AUTO, + CommunicationImplementation.RING ], # TODO(b/166682130): add MEAN reduce once the bug is fixed. - reduce_op=ReduceOp.SUM)) - def testAllReduceSparse(self, num_processes, required_gpus, communication, - reduce_op): + reduce_op=ReduceOp.SUM, + use_collective_v2=[True, False])) + def testAllReduceSparse(self, num_processes, required_gpus, implementation, + reduce_op, use_collective_v2): options = self.RunOptions( mode=["func_graph"], # Sparse reduce is not supported in eager. num_processes=num_processes, gpus_per_process=required_gpus, reduce_op=reduce_op, - communication=communication) + communication_options=collective_util.Options( + implementation=implementation), + use_collective_v2=use_collective_v2) group_size = options.num_processes * (options.gpus_per_process or 1) inputs_data = [ @@ -371,7 +388,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): self.reduce_and_verify(inputs, expect, options) - def testAllReduceSparseVariableLength(self): + @combinations.generate(combinations.combine(use_collective_v2=[True, False])) + def testAllReduceSparseVariableLength(self, use_collective_v2): # One device per process, 2 processes, 2 replicas in total. inputs = [ IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]), @@ -388,22 +406,28 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): self.RunOptions( mode=["func_graph"], # Sparse reduce is not supported in eager. num_processes=2, - reduce_op=ReduceOp.SUM)) + reduce_op=ReduceOp.SUM, + use_collective_v2=use_collective_v2)) @combinations.generate( combinations.combine( num_processes=[1, 2], required_gpus=[0, 1, 2], - communication=[ - CollectiveCommunication.AUTO, CollectiveCommunication.RING, - CollectiveCommunication.NCCL + implementation=[ + CommunicationImplementation.AUTO, + CommunicationImplementation.RING, CommunicationImplementation.NCCL ], - reduce_op=[ReduceOp.SUM, ReduceOp.MEAN])) - def testBatchAllReduceDense(self, num_processes, required_gpus, communication, - reduce_op): - if required_gpus == 0 and communication == CollectiveCommunication.NCCL: + reduce_op=[ReduceOp.SUM, ReduceOp.MEAN], + use_scoped_allocator=[True, False], + use_collective_v2=[True, False])) + def testBatchAllReduceDense(self, num_processes, required_gpus, + implementation, reduce_op, use_scoped_allocator, + use_collective_v2): + if (required_gpus == 0 and + implementation == CommunicationImplementation.NCCL): self.skipTest("Skip CPU + NCCL combination") - if num_processes == 2 and communication == CollectiveCommunication.NCCL: + if (num_processes == 2 and + implementation == CommunicationImplementation.NCCL): self.skipTest("Skip NCCL + 2 processes combination. NCCL requires " "physical GPUs for every process.") @@ -411,7 +435,10 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): num_processes=num_processes, gpus_per_process=required_gpus, reduce_op=reduce_op, - communication=communication) + communication_options=collective_util.Options( + implementation=implementation), + use_scoped_allocator=use_scoped_allocator, + use_collective_v2=use_collective_v2) group_size = options.num_processes * (options.gpus_per_process or 1) inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]] @@ -430,18 +457,23 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): combinations.combine( num_processes=[1, 2], required_gpus=[0, 1, 2], - communication=[ - CollectiveCommunication.AUTO, - CollectiveCommunication.RING, - CollectiveCommunication.NCCL, + implementation=[ + CommunicationImplementation.AUTO, + CommunicationImplementation.RING, + CommunicationImplementation.NCCL, ], # TODO(b/166682130): add MEAN reduce once the bug is fixed. - reduce_op=ReduceOp.SUM)) + reduce_op=ReduceOp.SUM, + use_scoped_allocator=[True, False], + use_collective_v2=[True, False])) def testBatchAllReduceSparse(self, num_processes, required_gpus, - communication, reduce_op): - if required_gpus == 0 and communication == CollectiveCommunication.NCCL: + implementation, reduce_op, use_scoped_allocator, + use_collective_v2): + if (required_gpus == 0 and + implementation == CommunicationImplementation.NCCL): self.skipTest("Skip CPU + NCCL combination") - if num_processes == 2 and communication == CollectiveCommunication.NCCL: + if (num_processes == 2 and + implementation == CommunicationImplementation.NCCL): self.skipTest("Skip NCCL + 2 processes combination. NCCL requires " "physical GPUs for every process.") @@ -450,7 +482,10 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): num_processes=num_processes, gpus_per_process=required_gpus, reduce_op=reduce_op, - communication=communication) + communication_options=collective_util.Options( + implementation=implementation), + use_scoped_allocator=use_scoped_allocator, + use_collective_v2=use_collective_v2) group_size = options.num_processes * (options.gpus_per_process or 1) inputs_data = ([ @@ -513,24 +548,26 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): required_gpus=[0, 1, 2], axis=[0, 1, 2], func_mode=["eager", "func_graph"], - communication=[ - CollectiveCommunication.NCCL, - CollectiveCommunication.AUTO, - CollectiveCommunication.RING - ])) - def testAllGatherSameShape(self, num_processes, required_gpus, communication, - func_mode, axis): + implementation=[ + CommunicationImplementation.NCCL, + CommunicationImplementation.AUTO, CommunicationImplementation.RING + ], + use_collective_v2=[True, False])) + def testAllGatherSameShape(self, num_processes, required_gpus, implementation, + func_mode, axis, use_collective_v2): def replica_fn(): + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = ( + use_collective_v2) collective, devices, _ = self.make_collective(num_processes, - required_gpus, - communication) + required_gpus) + options = collective_util.Options(implementation=implementation) value = constant_op.constant([[[1, 2], [1, 2]]], dtype=dtypes.float32) def gather_fn(): per_replica_value = make_per_replica_value(value, devices) gathered_values = collective._gather( - per_replica_value, per_replica_value, axis=axis) + per_replica_value, per_replica_value, axis=axis, options=options) gathered_values = self.as_list(gathered_values) # Skip checking devices in eager. In eager the device attribute doesn't # reflect the actual device of the tensor. @@ -554,19 +591,54 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): @combinations.generate( combinations.combine( - num_processes=1, - required_gpus=2, - communication=[ - CollectiveCommunication.NCCL, CollectiveCommunication.RING - ])) - def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes, - required_gpus, - communication): + num_processes=[1, 2], + required_gpus=[0, 1, 2], + implementation=[CommunicationImplementation.RING])) + def testCollectiveV2ControlFlow(self, num_processes, required_gpus, + implementation): def replica_fn(): + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = True collective, devices, _ = self.make_collective(num_processes, + required_gpus) + options = collective_util.Options(implementation=implementation) + value = make_per_replica_value(constant_op.constant([1.]), devices) + + @def_function.function + def reduce_fn(): + + def cond_body(): + reduced = collective.reduce(reduce_util.ReduceOp.SUM, value, value, + options) + return math_ops.add_n(self.as_list(reduced)) / len(devices) + + return control_flow_ops.cond( + array_ops.identity(False), cond_body, cond_body) + + num_replicas = num_processes * len(devices) + self.assertAllEqual(reduce_fn(), [1. * num_replicas]) + + get_global_mpr(num_processes).run(replica_fn) + + @combinations.generate( + combinations.combine( + num_processes=1, + required_gpus=2, + implementation=[ + CommunicationImplementation.NCCL, CommunicationImplementation.RING + ], + use_collective_v2=[True, False])) + def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes, required_gpus, - communication) + implementation, + use_collective_v2): + + def replica_fn(): + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = ( + use_collective_v2) + collective, devices, _ = self.make_collective(num_processes, + required_gpus) + options = collective_util.Options(implementation=implementation) # We would like to simulate the following sequence: # thread-0 device0 device1 @@ -595,14 +667,15 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): def thread_fn(): reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, - [(v0, v0), (v0, v0)]) + [(v0, v0), (v0, v0)], options) self.assertAllEqual(reduced[0].values, [2.0, 2.0]) self.assertAllEqual(reduced[1].values, [2.0, 2.0]) t = threading.Thread(target=thread_fn) t.start() reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v1, v1), - (v1, v1)]) + (v1, v1)], + options) self.assertAllEqual(reduced[0].values, [4.0, 4.0]) self.assertAllEqual(reduced[1].values, [4.0, 4.0]) t.join() @@ -613,16 +686,19 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): combinations.combine( num_processes=1, required_gpus=2, - communication=[ - CollectiveCommunication.NCCL, CollectiveCommunication.RING - ])) + implementation=[ + CommunicationImplementation.NCCL, CommunicationImplementation.RING + ], + use_collective_v2=[True, False])) def testInputsAreFunctionArgs(self, num_processes, required_gpus, - communication): + implementation, use_collective_v2): def replica_fn(): + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = ( + use_collective_v2) collective, devices, _ = self.make_collective(num_processes, - required_gpus, - communication) + required_gpus) + options = collective_util.Options(implementation=implementation) @def_function.function def reduce_fn(v): @@ -632,7 +708,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): # We only use NCCL for batch reduce with two or more values, so we use # two values here. reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), - (v, v)]) + (v, v)], + options) self.assertEqual(reduced[0].values[0].device, devices[0]) self.assertEqual(reduced[0].values[1].device, devices[1]) self.assertEqual(reduced[1].values[0].device, devices[0]) @@ -651,21 +728,26 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): combinations.combine( num_processes=2, required_gpus=[0, 1], - communication=[CollectiveCommunication.RING])) - def testTimeoutReduceDense(self, num_processes, communication, required_gpus): + implementation=[CommunicationImplementation.RING], + use_collective_v2=[True, False])) + def testTimeoutReduceDense(self, num_processes, implementation, required_gpus, + use_collective_v2): def replica_fn(): + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = ( + use_collective_v2) collective, devices, task_id = self.make_collective( - num_processes, required_gpus, communication) + num_processes, required_gpus) if task_id != 0: return v = make_per_replica_value(1.0, devices) - hints = collective_util.Hints(timeout_seconds=1) + options = collective_util.Options( + timeout_seconds=1, implementation=implementation) @def_function.function def reduce_dense(): - collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints) + return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options) # The collective should time out because we only launch it on worker-0, # while there're three workers in total. @@ -678,23 +760,27 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): combinations.combine( num_processes=2, required_gpus=[0, 1], - communication=[CollectiveCommunication.RING])) - def testTimeoutBatchReduceDense(self, num_processes, communication, - required_gpus): + implementation=[CommunicationImplementation.RING], + use_collective_v2=[True, False])) + def testTimeoutBatchReduceDense(self, num_processes, implementation, + required_gpus, use_collective_v2): def replica_fn(): + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = ( + use_collective_v2) collective, devices, task_id = self.make_collective( - num_processes, required_gpus, communication) + num_processes, required_gpus) if task_id != 0: return v = make_per_replica_value(1.0, devices) - hints = collective_util.Hints(timeout_seconds=1) + options = collective_util.Options( + timeout_seconds=1, implementation=implementation) @def_function.function def batch_reduce_dense(): - collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)], - hints) + return collective.batch_reduce(reduce_util.ReduceOp.SUM, + [(v, v), (v, v)], options) # The collective should time out because we only launch it on worker-0, # while there're two workers in total. @@ -707,24 +793,28 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): combinations.combine( num_processes=2, required_gpus=[0, 1], - communication=[CollectiveCommunication.RING])) - def testTimeoutReduceSparse(self, num_processes, communication, - required_gpus): + implementation=[CommunicationImplementation.RING], + use_collective_v2=[True, False])) + def testTimeoutReduceSparse(self, num_processes, implementation, + required_gpus, use_collective_v2): def replica_fn(): + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = ( + use_collective_v2) collective, devices, task_id = self.make_collective( - num_processes, required_gpus, communication) + num_processes, required_gpus) if task_id != 0: return v = make_per_replica_value( IndexedSlicesValue( values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices) - hints = collective_util.Hints(timeout_seconds=1) + options = collective_util.Options( + timeout_seconds=1, implementation=implementation) @def_function.function def reduce_sparse(): - collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints) + return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options) # The collective should time out because we only launch it on worker-0, # while there're two workers in total. @@ -737,25 +827,29 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): combinations.combine( num_processes=2, required_gpus=[0, 1], - communication=[CollectiveCommunication.RING])) + implementation=[CommunicationImplementation.RING], + use_collective_v2=[True, False])) def testTimeoutBatchReduceSparse(self, num_processes, required_gpus, - communication): + implementation, use_collective_v2): def replica_fn(): + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = ( + use_collective_v2) collective, devices, task_id = self.make_collective( - num_processes, required_gpus, communication) + num_processes, required_gpus) if task_id != 0: return v = make_per_replica_value( IndexedSlicesValue( values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices) - hints = collective_util.Hints(timeout_seconds=1) + options = collective_util.Options( + timeout_seconds=1, implementation=implementation) @def_function.function def batch_reduce_sparse(): - collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)], - hints) + return collective.batch_reduce(reduce_util.ReduceOp.SUM, + [(v, v), (v, v)], options) # The collective should time out because we only launch it on worker-0, # while there're two workers in total. diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py index 3920c2af6b0..96866fb1ca4 100644 --- a/tensorflow/python/distribute/cross_device_utils.py +++ b/tensorflow/python/distribute/cross_device_utils.py @@ -26,6 +26,7 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.ops import control_flow_ops @@ -255,6 +256,9 @@ class CollectiveKeys(object): class CollectiveReplicaLauncher(object): """Launch collectives on one replica.""" + _use_scoped_allocator = True + _use_collective_v2 = False + def __init__(self, group_key, group_size, @@ -281,6 +285,44 @@ class CollectiveReplicaLauncher(object): return ops.control_dependencies([control_input]) return ops.NullContextmanager() + def _should_use_collective_v2(self): + if not CollectiveReplicaLauncher._use_collective_v2: + return False + if not ops.executing_eagerly_outside_functions(): + return False + return True + + def _next_instance_key(self): + """Returns the next instance key.""" + if self._should_use_collective_v2(): + # Assigning instance keys at function building time have issues since + # different workers may retrace the function at different times. With + # collective V2 we can use capture_call_time_value to use a placeholder as + # the instance key and feed it at function call time. In this way we also + # don't reuse instance keys, which allows for per-instance cancellation. + graph = ops.get_default_graph() + # Control flow ops don't work with capture_call_time_value, so we put the + # capture in the function graph of that control flow op. + while getattr(graph, 'is_control_flow_graph', False): + graph = graph.outer_graph + if not context.executing_eagerly() and graph.building_function: + with graph.as_default(): + # Capture self._next_instance_key so that when building a function + # that calls another tf.function, the instance key assignment is + # further delayed until we actually call the function in eager. Note + # that capture_call_time_value doesn't automatically propagate the + # deferred capture to the outer function. + return graph.capture_call_time_value( + self._next_instance_key, tensor_spec.TensorSpec([], dtypes.int32)) + else: + instance_key = self._collective_keys.get_instance_key( + self._group_key, self._device) + with ops.device('CPU:0'): + return ops.convert_to_tensor(instance_key, dtype=dtypes.int32) + else: + return self._collective_keys.get_instance_key(self._group_key, + self._device) + def all_reduce(self, input_tensor, control_input=None, @@ -302,18 +344,60 @@ class CollectiveReplicaLauncher(object): Returns: The reduced tensor. """ - instance_key = self._collective_keys.get_instance_key( - self._group_key, self._device) + instance_key = self._next_instance_key() with self._executor_scope(), \ ops.device(self._device), \ self._control_input(control_input): - return collective_ops.all_reduce( - input_tensor, - self._group_size, - self._group_key, - instance_key, - communication_hint=communication_hint, - timeout=timeout) + if self._should_use_collective_v2(): + return collective_ops.all_reduce_v2( + input_tensor, + self._group_size, + self._group_key, + instance_key, + communication_hint=communication_hint, + timeout=timeout) + else: + return collective_ops.all_reduce( + input_tensor, + self._group_size, + self._group_key, + instance_key, + communication_hint=communication_hint, + timeout=timeout) + + def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0): + """All-gather a dense tensor. + + This can be called in eager mode if an async executor is supplied when + creating the launcher. + + Args: + input_tensor: a dense tensor. It must have the same shape on all replicas. + communication_hint: string providing hint to runtime for choosing + collective implementation. + timeout: a float. The timeout in seconds. + + Returns: + The reduced tensor. + """ + instance_key = self._next_instance_key() + with self._executor_scope(), ops.device(self._device): + if self._should_use_collective_v2(): + return collective_ops.all_gather_v2( + input_tensor, + self._group_size, + self._group_key, + instance_key, + communication_hint=communication_hint, + timeout=timeout) + else: + return collective_ops.all_gather( + input_tensor, + self._group_size, + self._group_key, + instance_key, + communication_hint=communication_hint, + timeout=timeout) def batch_all_reduce(self, input_tensor_packs, @@ -337,22 +421,45 @@ class CollectiveReplicaLauncher(object): Returns: A flat list of reduced tensors. """ + # We don't batch with concat in eager. It's easy to get it wrong because + # we need to avoid any numpy() calls on values produced by the async + # executor. This effectively disables batching in eager, but it's unlikely + # to all-reduce a large number of tensors in eager. + batch_with_concat = (not self._use_scoped_allocator and + not context.executing_eagerly()) outputs = [] for pack in input_tensor_packs: - # By placing all CollectiveReduce ops in a batch under single name scope, - # we ensure they will be picked up by the `ScopedAllocator` grappler - # optimizer and packed into a single all-reduce. - with ops.name_scope('allreduce'): - # TODO(b/169168846): inserts a parallel all_gather to verify packings - # are the same on each replica. - for input_tensor in pack: + # TODO(b/169168846): inserts a parallel all_gather to verify packings + # are the same on each replica. + if batch_with_concat: + with ops.device(self._device): + flat_tensors = [array_ops.reshape(t, [-1]) for t in pack] + shapes = [array_ops.shape(t) for t in pack] if communication_hint == 'NCCL' and outputs: control_input = outputs[-1] else: control_input = None - outputs.append( - self.all_reduce(input_tensor, control_input, communication_hint, - timeout)) + reduced = self.all_reduce( + array_ops.concat(flat_tensors, axis=0), control_input, + communication_hint, timeout) + num_elements = [math_ops.reduce_prod(s) for s in shapes] + flat_outputs = array_ops.split(reduced, num_elements, axis=0) + for shape, flat_output in zip(shapes, flat_outputs): + outputs.append(array_ops.reshape(flat_output, shape)) + else: + # By placing all CollectiveReduce ops in a batch under single name + # scope, we ensure they will be picked up by the `ScopedAllocator` + # grappler optimizer and packed into a single all-reduce. + with ops.name_scope('allreduce'): + for input_tensor in pack: + if communication_hint == 'NCCL' and outputs: + control_input = outputs[-1] + else: + control_input = None + outputs.append( + self.all_reduce(input_tensor, control_input, communication_hint, + timeout)) + return outputs def all_gather(self, @@ -383,11 +490,8 @@ class CollectiveReplicaLauncher(object): if context.executing_eagerly(): raise RuntimeError('all_gather in eager mode is not supported') - instance_key_tensor = self._collective_keys.get_instance_key( - self._group_key, self._device) - instance_key_shape = self._collective_keys.get_instance_key( - self._group_key, self._device) - with ops.device(self._device): + with ops.device(self._device), \ + ops.control_dependencies([array_ops.identity(input_tensor)]): # 1. Transpose # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3, # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which @@ -399,11 +503,8 @@ class CollectiveReplicaLauncher(object): axis=0) input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre) # 2. Pad - gathered_shape = collective_ops.all_gather( + gathered_shape = self._all_gather( array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0), - self._group_size, - self._group_key, - instance_key_shape, communication_hint, timeout=timeout) first_dims = gathered_shape[:, 0] @@ -411,16 +512,11 @@ class CollectiveReplicaLauncher(object): padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim) # 3. Gather - gather_padded_out_tensor = collective_ops.all_gather( - padded_input_tensor, - self._group_size, - self._group_key, - instance_key_tensor, - communication_hint, - timeout=timeout) + gather_padded_out_tensor = self._all_gather( + padded_input_tensor, communication_hint, timeout=timeout) # 4. Unpad split_tensors = [] - for i in range(first_dims.shape[0]): + for i in range(self._group_size): start_pos = i * full_axis_dim split_tensors.append(gather_padded_out_tensor[start_pos:start_pos + first_dims[i]]) @@ -457,15 +553,6 @@ class CollectiveReplicaLauncher(object): raise RuntimeError( 'all_reduce_indexed_slices in eager mode is not supported') - gather_length_key = self._collective_keys.get_instance_key( - self._group_key, self._device) - gather_indices_key = self._collective_keys.get_instance_key( - self._group_key, self._device) - gather_values_key = self._collective_keys.get_instance_key( - self._group_key, self._device) - reduce_densified_key = self._collective_keys.get_instance_key( - self._group_key, self._device) - # Current CollectiveAllGather implementations require input IndexedSlices to # have consistent length across the board, we handle the reduction of # IndexedSlices as follows: @@ -477,23 +564,13 @@ class CollectiveReplicaLauncher(object): def all_gather(): """Use all_gather to aggregate `IndexedSlices`.""" - all_values = collective_ops.all_gather( - input_slices.values, - self._group_size, - self._group_key, - gather_values_key, - communication_hint, - timeout=timeout) + all_values = self._all_gather( + input_slices.values, communication_hint, timeout=timeout) # Add control dependency to order the all-gather. control = [all_values] if communication_hint == 'NCCL' else [] with ops.control_dependencies(control): - all_indices = collective_ops.all_gather( - input_slices.indices, - self._group_size, - self._group_key, - gather_indices_key, - communication_hint, - timeout=timeout) + all_indices = self._all_gather( + input_slices.indices, communication_hint, timeout=timeout) return ops.IndexedSlices( values=all_values, indices=all_indices, @@ -502,15 +579,8 @@ class CollectiveReplicaLauncher(object): def densify_and_all_reduce(): """Use all_reduce to aggregate `IndexedSlices`.""" densified = ops.convert_to_tensor(input_slices) - reduced = collective_ops.all_reduce( - densified, - self._group_size, - self._group_key, - reduce_densified_key, - 'Add', - 'Id', [0], - communication_hint, - timeout=timeout) + reduced = self.all_reduce( + densified, communication_hint=communication_hint, timeout=timeout) # We have to convert dense grad to IndexedSlice because all_reduce() # and all_gather() must have the same return type as required by # control_flow_ops.cond. @@ -520,13 +590,8 @@ class CollectiveReplicaLauncher(object): dense_shape=input_slices.dense_shape) length = array_ops.shape(input_slices.indices) - all_lengths = collective_ops.all_gather( - length, - self._group_size, - self._group_key, - gather_length_key, - communication_hint, - timeout=timeout) + all_lengths = self._all_gather( + length, communication_hint, timeout=timeout) return control_flow_ops.cond( math_ops.equal( math_ops.reduce_max(all_lengths), diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 7cd5eb1c85a..de55b639f6b 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -295,7 +295,8 @@ def get_loss_reduction(): # Internal API for validating the current thread mode -def _require_cross_replica_or_default_context_extended(extended): +def _require_cross_replica_or_default_context_extended(extended, + error_message=None): """Verify in cross-replica context.""" context = _get_per_thread_mode() cross_replica = context.cross_replica_context @@ -308,8 +309,10 @@ def _require_cross_replica_or_default_context_extended(extended): if context.strategy is not strategy: _wrong_strategy_scope(strategy, context) assert cross_replica is None - raise RuntimeError("Method requires being in cross-replica context, use " + if not error_message: + error_message = ("Method requires being in cross-replica context, use " "get_replica_context().merge_call()") + raise RuntimeError(error_message) def _wrong_strategy_scope(strategy, context): @@ -439,8 +442,12 @@ class InputReplicationMode(enum.Enum): Replicas will dequeue from the local Dataset on their worker. `tf.distribute.Strategy` doesn't manage any state sharing between such separate input pipelines. + * `PER_REPLICA`: The input function will be called on each replica seperately. + `tf.distribute.Strategy` doesn't manage any state sharing between such + separate input pipelines. """ PER_WORKER = "PER_WORKER" + PER_REPLICA = "PER_REPLICA" @tf_export("distribute.InputContext") @@ -616,6 +623,8 @@ class RunOptions( class InputOptions( collections.namedtuple("InputOptions", [ "experimental_prefetch_to_device", + "experimental_replication_mode", + "experimental_place_dataset_on_device", ])): """Run options for `experimental_distribute_dataset(s_from_function)`. @@ -633,19 +642,36 @@ class InputOptions( strategy.experimental_distribute_dataset( dataset, tf.distribute.InputOptions( - experimental_prefetch_to_device=False))) + experimental_replication_mode= + experimental_replication_mode.PER_WORKER, + experimental_place_dataset_on_device=False))) ``` Attributes: experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset elements will be prefetched to accelerator device memory. When False, dataset elements are prefetched to host device memory. Must be False when - using TPUEmbedding API. + using TPUEmbedding API. experimental_prefetch_to_device can only be used + with experimental_replication_mode=PER_WORKER + experimental_replication_mode: Replication mode for the input function. + Currently, the InputReplicationMode.PER_REPLICA is only supported with + tf.distribute.MirroredStrategy. + experimental_distribute_datasets_from_function. + The default value is InputReplicationMode.PER_WORKER. + experimental_place_dataset_on_device: Boolean. Default to False. When True, + dataset will be placed on the device, otherwise it will remain on the + host. experimental_place_dataset_on_device=True can only be used with + experimental_replication_mode=PER_REPLICA """ - def __new__(cls, experimental_prefetch_to_device=True): - return super(InputOptions, cls).__new__(cls, - experimental_prefetch_to_device) + def __new__(cls, + experimental_prefetch_to_device=True, + experimental_replication_mode=InputReplicationMode.PER_WORKER, + experimental_place_dataset_on_device=False): + return super(InputOptions, + cls).__new__(cls, experimental_prefetch_to_device, + experimental_replication_mode, + experimental_place_dataset_on_device) # ------------------------------------------------------------------------------ # Base classes for all distribution strategies. @@ -1431,107 +1457,6 @@ class StrategyBase(object): denom = math_ops.cast(denom, numer.dtype) return math_ops.truediv(numer, denom) - # TODO(wxinyi): generate docs after it is implemented for all strategies. - # TODO(wxinyi): hide from V1 API - def _gather(self, value, axis): - # pylint: disable=line-too-long, protected-access - """Gather `value` across replicas along `axis` to the current device. - - Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like - object `value`, this API gathers and concatenates `value` along the - `axis`-th dimension. The result is copied to the "current" device - which - would typically be the CPU of the worker on which the program is running. - For `tf.distribute.TPUStrategy`, it is the first TPU host. For multi-client - `MultiWorkerMirroredStrategy`, this is CPU of each worker. - - This API can only be called in the cross-replica context. For a counterpart - in the replica context, see `tf.distribute.ReplicaContext.all_gather`. - - Note: the input `value` on different replicas must have the same rank, and - they must have shapes that are consistent along all dimensions except the - `axis`-th dimension. For example, given a `tf.distribute.DistributedValues` - with tensors of shape `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can - call `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or - `gather(..., axis=2, ...)`. - - - # TODO(wxinyi): convert to testable docstring after implemented for MirroredStrategy - ```python - strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) - local_tensor = tf.constant([[1, 2], [3, 4]]) - distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(local_tensor)) - @tf.function - def run(): - return strategy.gather(distributed_values, axis=0) - run() - # - ``` - - Some more example cases: - - ```python - strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"]) - local_tensor = tf.reshape(tf.range(6), shape=(1,2,3)) - distributed_values = strategy.experimental_distribute_values_from_function(lambda _: local_tensor) - @tf.function - def run(): - return strategy.gather(distributed_values, axis=AXIS) - run() - - # With AXIS=0, the result is - # - # With AXIS=1, the result is - # - # With AXIS=2, the result is - # - - ``` - - Args: - value: a `tf.distribute.DistributedValues` instance, e.g. returned by - `Strategy.run`, to be combined into a single tensor. It can also be a - regular tensor when used with `OneDeviceStrategy` or default strategy. - The underlying tensor constructs can only be dense tensors with non-zero - rank, NOT `tf.IndexedSlices`. - axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the - range [0, rank(value)). - - Returns: - A `Tensor` that's the concatenation of `value` across replicas along - `axis` dimension. - """ - # pylint: enable=line-too-long - _require_cross_replica_or_default_context_extended(self._extended) - dst = device_util.current( - ) or self._extended._default_device or "/device:CPU:0" - if isinstance(value, ops.IndexedSlices): - raise NotImplementedError("gather/all_gather does not support " - "IndexedSlices") - return self._extended._local_results( - self._extended._gather_to(value, dst, axis))[0] - @doc_controls.do_not_doc_inheritable # DEPRECATED def unwrap(self, value): """Returns the list of all local per-replica values contained in `value`. @@ -1764,6 +1689,112 @@ class Strategy(StrategyBase): return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access value_fn) + def gather(self, value, axis): + # pylint: disable=line-too-long, protected-access + """Gather `value` across replicas along `axis` to the current device. + + Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like + object `value`, this API gathers and concatenates `value` across replicas + along the `axis`-th dimension. The result is copied to the "current" device + - which would typically be the CPU of the worker on which the program is + running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For + multi-client `MultiWorkerMirroredStrategy`, this is CPU of each worker. + + This API can only be called in the cross-replica context. For a counterpart + in the replica context, see `tf.distribute.ReplicaContext.all_gather`. + + Note: For all strategies except `tf.distribute.TPUStrategy`, the input + `value` on different replicas must have the same rank, and their shapes must + be the same in all dimensions except the `axis`-th dimension. In other + words, their shapes cannot be different in a dimension `d` where `d` does + not equal to the `axis` argument. For example, given a + `tf.distribute.DistributedValues` with component tensors of shape + `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call + `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or + `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`, + all tensors must have exactly the same rank and same shape. + + Note: Given a `tf.distribute.DistributedValues` `value`, its component + tensors must have a non-zero rank. Otherwise, consider using + `tf.expand_dims` before gathering them. + + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) + >>> # A DistributedValues with component tensor of shape (2, 1) on each replica + ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]]))) + >>> @tf.function + ... def run(): + ... return strategy.gather(distributed_values, axis=0) + >>> run() + + + + Consider the following example for more combinations: + + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"]) + >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3)) + >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor)) + >>> @tf.function + ... def run(axis): + ... return strategy.gather(distributed_values, axis=axis) + >>> axis=0 + >>> run(axis) + + >>> axis=1 + >>> run(axis) + + >>> axis=2 + >>> run(axis) + + + + Args: + value: a `tf.distribute.DistributedValues` instance, e.g. returned by + `Strategy.run`, to be combined into a single tensor. It can also be a + regular tensor when used with `tf.distribute.OneDeviceStrategy` or the + default strategy. The tensors that constitute the DistributedValues + can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`. + axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the + range [0, rank(value)). + + Returns: + A `Tensor` that's the concatenation of `value` across replicas along + `axis` dimension. + """ + # pylint: enable=line-too-long + error_message = ("tf.distribute.Strategy.gather method requires " + "cross-replica context, use " + "get_replica_context().all_gather() instead") + _require_cross_replica_or_default_context_extended(self._extended, + error_message) + dst = device_util.current( + ) or self._extended._default_device or "/device:CPU:0" + if isinstance(value, ops.IndexedSlices): + raise NotImplementedError("gather does not support IndexedSlices") + return self._extended._local_results( + self._extended._gather_to(value, dst, axis))[0] + # TF v1.x version has additional deprecated APIs @tf_export(v1=["distribute.Strategy"]) @@ -2182,7 +2213,7 @@ class StrategyExtendedV2(object): dst = device_util.current() or self._default_device or "/device:CPU:0" return self._local_results(self.reduce_to(reduce_op, value, dst))[0] - def reduce_to(self, reduce_op, value, destinations, experimental_hints=None): + def reduce_to(self, reduce_op, value, destinations, options=None): """Combine (via e.g. sum or mean) values across replicas. `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed @@ -2247,12 +2278,17 @@ class StrategyExtendedV2(object): `destinations`. Note that if it's a `tf.Variable`, the value is reduced to the devices of that variable, and this method doesn't update the variable. - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See - `tf.distribute.experimental.CollectiveHints` for details. + options: a `tf.distribute.experimental.CommunicationOptions`. Options to + perform collective operations. This overrides the default options if the + `tf.distribute.Strategy` takes one in the constructor. See + `tf.distribute.experimental.CommunicationOptions` for details of the + options. Returns: A tensor or value reduced to `destinations`. """ + if options is None: + options = collective_util.Options() _require_cross_replica_or_default_context_extended(self) assert not isinstance(destinations, (list, tuple)) assert not isinstance(reduce_op, variable_scope.VariableAggregation) @@ -2260,17 +2296,12 @@ class StrategyExtendedV2(object): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) assert (reduce_op == reduce_util.ReduceOp.SUM or reduce_op == reduce_util.ReduceOp.MEAN) - if experimental_hints is None: - experimental_hints = collective_util.Hints() - return self._reduce_to(reduce_op, value, destinations, experimental_hints) + return self._reduce_to(reduce_op, value, destinations, options) - def _reduce_to(self, reduce_op, value, destinations, experimental_hints): + def _reduce_to(self, reduce_op, value, destinations, options): raise NotImplementedError("must be implemented in descendants") - def batch_reduce_to(self, - reduce_op, - value_destination_pairs, - experimental_hints=None): + def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None): """Combine multiple `reduce_to` calls into one for faster execution. Similar to `reduce_to`, but accepts a list of (value, destinations) pairs. @@ -2325,30 +2356,30 @@ class StrategyExtendedV2(object): "SUM", "MEAN". value_destination_pairs: a sequence of (value, destinations) pairs. See `tf.distribute.Strategy.reduce_to` for descriptions. - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See - `tf.distribute.experimental.CollectiveHints` for details. + options: a `tf.distribute.experimental.CommunicationOptions`. Options to + perform collective operations. This overrides the default options if the + `tf.distribute.Strategy` takes one in the constructor. See + `tf.distribute.experimental.CommunicationOptions` for details of the + options. Returns: A list of reduced values, one per pair in `value_destination_pairs`. """ + if options is None: + options = collective_util.Options() _require_cross_replica_or_default_context_extended(self) assert not isinstance(reduce_op, variable_scope.VariableAggregation) if isinstance(reduce_op, six.string_types): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) - if experimental_hints is None: - experimental_hints = collective_util.Hints() - return self._batch_reduce_to(reduce_op, value_destination_pairs, - experimental_hints) + return self._batch_reduce_to(reduce_op, value_destination_pairs, options) - def _batch_reduce_to(self, reduce_op, value_destination_pairs, - experimental_hints): + def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): return [ - self.reduce_to( - reduce_op, t, destinations=v, experimental_hints=experimental_hints) + self.reduce_to(reduce_op, t, destinations=v, options=options) for t, v in value_destination_pairs ] - def _gather_to(self, value, destinations, axis, experimental_hints=None): + def _gather_to(self, value, destinations, axis, options=None): """Gather `value` across replicas along axis-th dimension to `destinations`. `gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like @@ -2365,31 +2396,30 @@ class StrategyExtendedV2(object): variable. axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the range [0, rank(value)). - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See - `tf.distribute.experimental.CollectiveHints` for details. + options: a `tf.distribute.experimental.CommunicationOptions`. Options to + perform collective operations. This overrides the default options if the + `tf.distribute.Strategy` takes one in the constructor. See + `tf.distribute.experimental.CommunicationOptions` for details of the + options. Returns: A tensor or value gathered to `destinations`. """ _require_cross_replica_or_default_context_extended(self) assert not isinstance(destinations, (list, tuple)) - if experimental_hints is None: - experimental_hints = collective_util.Hints() - return self._gather_to_implementation(value, destinations, axis, experimental_hints) + if options is None: + options = collective_util.Options() + return self._gather_to_implementation(value, destinations, axis, options) - def _gather_to_implementation(self, value, destinations, axis, experimental_hints): + def _gather_to_implementation(self, value, destinations, axis, options): raise NotImplementedError("_gather_to must be implemented in descendants") - def _batch_gather_to(self, - value_destination_pairs, - axis, - experimental_hints=None): + def _batch_gather_to(self, value_destination_pairs, axis, options=None): _require_cross_replica_or_default_context_extended(self) - if experimental_hints is None: - experimental_hints = collective_util.Hints() + if options is None: + options = collective_util.Options() return [ - self._gather_to( - t, destinations=v, axis=axis, experimental_hints=experimental_hints) + self._gather_to(t, destinations=v, axis=axis, options=options) for t, v in value_destination_pairs ] @@ -2410,7 +2440,8 @@ class StrategyExtendedV2(object): Example usage: ```python - strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2 devices + strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2 + devices with strategy.scope(): v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) def update_fn(v): @@ -2811,8 +2842,7 @@ class StrategyExtendedV1(StrategyExtendedV2): # It sets the current Strategy for purposes of # `get_strategy()` and `has_strategy()` # and switches the thread mode to a "cross-replica context". -@tf_export("distribute.ReplicaContext") -class ReplicaContext(object): +class ReplicaContextBase(object): """A class with a collection of APIs that can be called in a replica context. You can use `tf.distribute.get_replica_context` to get an instance of @@ -2905,6 +2935,7 @@ class ReplicaContext(object): require_replica_context(self) if kwargs is None: kwargs = {} + merge_fn = autograph.tf_convert( merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False) return self._merge_call(merge_fn, args, kwargs) @@ -2975,7 +3006,7 @@ class ReplicaContext(object): require_replica_context(self) return (device_util.current(),) - def all_reduce(self, reduce_op, value, experimental_hints=None): + def all_reduce(self, reduce_op, value, options=None): """All-reduces `value` across all replicas. >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) @@ -2988,7 +3019,7 @@ class ReplicaContext(object): ) It supports batched operations. You can pass a list of values and it - attempts to batch them when possible. You can also specify `experimental_hints` + attempts to batch them when possible. You can also specify `options` to indicate the desired batching behavior, e.g. batch the values into multiple packs so that they can better overlap with computations. @@ -3028,8 +3059,11 @@ class ReplicaContext(object): value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts. The structure and the shapes of the `tf.Tensor` need to be same on all replicas. - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints - to perform collective operations. + options: a `tf.distribute.experimental.CommunicationOptions`. Options to + perform collective operations. This overrides the default options if the + `tf.distribute.Strategy` takes one in the constructor. See + `tf.distribute.experimental.CommunicationOptions` for details of the + options. Returns: A nested structure of `tf.Tensor` with the reduced values. The structure @@ -3037,13 +3071,13 @@ class ReplicaContext(object): """ if isinstance(reduce_op, six.string_types): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) - if experimental_hints is None: - experimental_hints = collective_util.Hints() + if options is None: + options = collective_util.Options() def batch_all_reduce(strategy, *value_flat): return strategy.extended.batch_reduce_to( reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat], - experimental_hints) + options) if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. @@ -3069,77 +3103,118 @@ class ReplicaContext(object): # to that point that the first result is needed. Most likely this can be # implemented in terms of `merge_call()` and `batch_reduce_to()`. - # TODO(wxinyi): generate docs after it is implemented for all strategies. - def _all_gather(self, value, axis, experimental_hints=None): + +@tf_export("distribute.ReplicaContext", v1=[]) +class ReplicaContext(ReplicaContextBase): + + __doc__ = ReplicaContextBase.__doc__ + + def all_gather(self, value, axis, options=None): """All-gathers `value` across all replicas along `axis`. - Note: An `all_gather` method can only be called in replica context. To find + Note: An `all_gather` method can only be called in replica context. For a cross-replica context counterpart, see `tf.distribute.Strategy.gather`. All replicas need to participate in the all-gather, otherwise this operation hangs. So if `all_gather` is called in any replica, it must be called in all replicas. - Note: If there're multiple all-gather calls, they need to execute in - the same order on all replicas. Dispatching all-gather based on conditions + Note: If there are multiple `all_gather` calls, they need to be executed in + the same order on all replicas. Dispatching `all_gather` based on conditions is usually error-prone. - # TODO(wxinyi): convert to testable docstring after implemented for MirroredStrategy - ```python - strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"]) - @tf.function - def gather_value(): - ctx = tf.distribute.get_replica_context() - value = tf.constant([1, 2, 3]) - # all_gather a `tf.distribute.DistributedValues` - return strategy.run(ctx.all_gather(value, axis=0)) - strategy.experimental_local_results(gather_value) - # Result: - # (, - # ) - ``` + For all strategies except `tf.distribute.TPUStrategy`, the input + `value` on different replicas must have the same rank, and their shapes must + be the same in all dimensions except the `axis`-th dimension. In other + words, their shapes cannot be different in a dimension `d` where `d` does + not equal to the `axis` argument. For example, given a + `tf.distribute.DistributedValues` with component tensors of shape + `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call + `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)` + or `all_gather(..., axis=2, ...)`. However, with + `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and + same shape. - ```python - strategy = tf.distribute.MirroredStrategy(["GPU:0", "CPU:0"]) - @tf.function - def gather_nest(): - ctx = tf.distribute.get_replica_context() - value_1 = tf.constant([1, 2, 3]) - value_2 = tf.constant([[1, 2], [3, 4]]) - # all_gather a nest of `tf.distribute.DistributedValues` - return ctx.all_gather([value_1, value_2], axis=0) - strategy.experimental_local_results(gather_nest) - # Result: - # ([, , >> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) + >>> @tf.function + ... def gather_value(): + ... ctx = tf.distribute.get_replica_context() + ... local_value = tf.constant([1, 2, 3]) + ... return ctx.all_gather(local_value, axis=0) + >>> result = strategy.run(gather_value) + >>> result + PerReplica:{ + 0: , + 1: + } + >>> strategy.experimental_local_results(result) + (, + ) + + + You can also pass in a nested structure of tensors to all-gather, say, a + list: + + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) + >>> @tf.function + ... def gather_nest(): + ... ctx = tf.distribute.get_replica_context() + ... value_1 = tf.constant([1, 2, 3]) + ... value_2 = tf.constant([[1, 2], [3, 4]]) + ... # all_gather a nest of `tf.distribute.DistributedValues` + ... return ctx.all_gather([value_1, value_2], axis=0) + >>> result = strategy.run(gather_nest) + >>> result + [PerReplica:{ + 0: , + 1: + }, PerReplica:{ + 0: , + 1: + }] + >>> strategy.experimental_local_results(result) + ([PerReplica:{ + 0: , + 1: + }, PerReplica:{ + 0: , + 1: + }],) + + + What if you are all-gathering tensors with different shapes on different + replicas? Consider the following example with two replicas, where you have + `value` as a nested structure consisting of two items to all-gather, `a` and + `b`. + + On Replica 0, `value` is {'a': [0], 'b': [[0, 1]]} + On Replica 1, `value` is {'a': [1], 'b': [[2, 3], [4, 5]]} + + Result for `all_gather` with `axis`=0: (on each of the replicas): {'a': [1, 2], 'b': [[0, 1], [2, 3], [4, 5]]} - Note: an input to be all_gathered must have the same rank on different - replicas, and they must have shapes that are consistent along all dimensions - except the `axis`-th dimension. For example, given a - `tf.distribute.DistributedValues` with tensors of shape `(1, 2, 3)` and - `(1, 3, 3)` on two replicas, you can call `all_gather(..., axis=1, ...)` on - it, but not `all_gather(..., axis=0, ...)` or `all_gather(..., axis=2, ...)`. - - Args: value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts, or a `tf.distribute.DistributedValues` instance. The structure of the @@ -3147,8 +3222,11 @@ class ReplicaContext(object): constructs can only be dense tensors with non-zero rank, NOT `tf.IndexedSlices`. axis: 0-D int32 Tensor. Dimension along which to gather. - experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints - to perform collective operations. + options: a `tf.distribute.experimental.CommunicationOptions`. Options to + perform collective operations. This overrides the default options if the + `tf.distribute.Strategy` takes one in the constructor. See + `tf.distribute.experimental.CommunicationOptions` for details of the + options. Returns: A nested structure of `tf.Tensor` with the gathered values. The structure @@ -3156,26 +3234,45 @@ class ReplicaContext(object): """ for v in nest.flatten(value): if isinstance(v, ops.IndexedSlices): - raise NotImplementedError("gather/all_gather does not support " - "IndexedSlices") + raise NotImplementedError("all_gather does not support IndexedSlices") - if experimental_hints is None: - experimental_hints = collective_util.Hints() + if options is None: + options = collective_util.Options() def batch_all_gather(strategy, *value_flat): return strategy.extended._batch_gather_to( # pylint: disable=protected-access [(v, _batch_reduce_destination(v)) for v in value_flat], axis, - experimental_hints) + options) @custom_gradient.custom_gradient def grad_wrapper(*xs): ys = self.merge_call(batch_all_gather, args=xs) - # The gradient of an all-gather is itself an all-gather. - return ys, lambda *dy_s: self._all_gather(dy_s, axis) + + def grad(*dy_s): + grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s) + new_grads = [] + for i, grad in enumerate(grads): + input_shape = array_ops.shape(xs[i]) + axis_dim = array_ops.reshape(input_shape[axis], [1]) + with ops.control_dependencies([array_ops.identity(grads)]): + d = self.all_gather(axis_dim, axis=0) + begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group]) + end_dim = begin_dim + array_ops.shape(xs[i])[axis] + new_grad = array_ops.gather( + grad, axis=axis, indices=math_ops.range(begin_dim, end_dim)) + new_grads.append(new_grad) + return new_grads + + return ys, grad return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) +@tf_export(v1=["distribute.ReplicaContext"]) +class ReplicaContextV1(ReplicaContextBase): + __doc__ = ReplicaContextBase.__doc__ + + def _batch_reduce_destination(x): """Returns the destinations for batch all-reduce.""" if isinstance(x, ops.Tensor): @@ -3319,13 +3416,13 @@ class _DefaultDistributionExtended(StrategyExtendedV1): with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0): return fn(*args, **kwargs) - def _reduce_to(self, reduce_op, value, destinations, experimental_hints): + def _reduce_to(self, reduce_op, value, destinations, options): # TODO(josh11b): Use destinations? - del reduce_op, destinations, experimental_hints + del reduce_op, destinations, options return value - def _gather_to_implementation(self, value, destinations, axis, experimental_hints): - del destinations, axis, experimental_hints + def _gather_to_implementation(self, value, destinations, axis, options): + del destinations, axis, options return value def _update(self, var, fn, args, kwargs, group): diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py index 04a1c5dc903..7533b3a35ca 100644 --- a/tensorflow/python/distribute/distribute_lib_test.py +++ b/tensorflow/python/distribute/distribute_lib_test.py @@ -94,8 +94,8 @@ class _TestExtended(distribute_lib.StrategyExtendedV1): def _local_results(self, value): return (value,) - def _reduce_to(self, reduce_op, value, destinations, experimental_hints): - del reduce_op, destinations, experimental_hints + def _reduce_to(self, reduce_op, value, destinations, options): + del reduce_op, destinations, options return value def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index d01cedcead0..d4a69846695 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -25,6 +25,7 @@ import six from tensorflow.python import tf2 from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.experimental.ops import distribute from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import multi_device_iterator_ops @@ -35,6 +36,7 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import input_ops from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.distribute_lib import InputReplicationMode from tensorflow.python.eager import context from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op @@ -108,7 +110,8 @@ def get_distributed_dataset(dataset, def get_distributed_datasets_from_function(dataset_fn, input_workers, input_contexts, - strategy): + strategy, + options=None): """Returns a distributed dataset from the given input function. This is a common function that is used by all strategies to return a @@ -126,22 +129,43 @@ def get_distributed_datasets_from_function(dataset_fn, `worker_device_pairs`. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. + options: Default is None. `tf.distribute.InputOptions` used to control + options on how this dataset is distributed. Returns: A distributed dataset instance. + + Raises: + ValueError: if `options.experimental_replication_mode` and + `options.experimental_place_dataset_on_device` are not consistent """ + if (options is not None and + options.experimental_replication_mode != InputReplicationMode.PER_REPLICA + and options.experimental_place_dataset_on_device): + raise ValueError( + "When `experimental_place_dataset_on_device` is set for dataset " + "placement, you must also specify `PER_REPLICA` for the " + "replication mode") + + if (options is not None and + options.experimental_replication_mode == InputReplicationMode.PER_REPLICA + and options.experimental_prefetch_to_device and + options.experimental_place_dataset_on_device): + raise ValueError( + "`experimental_place_dataset_on_device` can not be set to True " + "when experimental_prefetch_to_device is True and " + "replication mode is set to `PER_REPLICA`") + if tf2.enabled(): - return DistributedDatasetsFromFunction( - dataset_fn, - input_workers, - input_contexts, - strategy) + return DistributedDatasetsFromFunction(dataset_fn, input_workers, + input_contexts, strategy, options) else: return DistributedDatasetsFromFunctionV1( dataset_fn, input_workers, input_contexts, - strategy) + strategy, + options) @tf_export("distribute.DistributedIterator", v1=[]) @@ -535,7 +559,8 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False): flattened_data = [] for per_worker_data in replicas: flattened_data.extend(per_worker_data) - replicas = distribute_utils.regroup(flattened_data) + replicas = _create_per_replica( + flattened_data, strategy, get_next_as_optional=True) # Run an all-reduce to see whether any worker has values. # TODO(b/131423105): we should be able to short-cut the all-reduce in some @@ -635,7 +660,8 @@ class DistributedIteratorBase(DistributedIteratorInterface): # Make `replicas` a flat list of values across all replicas. replicas.extend( self._iterators[i].get_next_as_list_static_shapes(new_name)) - return distribute_utils.regroup(replicas) + return _create_per_replica( + replicas, self._strategy, get_next_as_optional=False) out_of_range_replicas = [] def out_of_range_fn(worker_index, device): @@ -669,7 +695,8 @@ class DistributedIteratorBase(DistributedIteratorInterface): results.append(result) replicas = results - return distribute_utils.regroup(replicas) + return _create_per_replica(replicas, self._strategy, + self._enable_get_next_as_optional) class DistributedIteratorV1(DistributedIteratorBase): @@ -869,11 +896,25 @@ class DistributedIterator(DistributedIteratorBase, @property def element_spec(self): + # When partial batch handling is enabled, always set the batch dimension to + # None, otherwise we just follow element_spec of the underlying dataset + # (whose batch dimension may also be None). This is because with partial + # batching handling we could always produce empty batches. + # + # TODO(b/163362689): avoid this once we have more elegent way to handle + # retracing and collectives. + if (self._enable_get_next_as_optional and + self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access + return nest.map_structure( + _rebatch_as_dynamic, self._element_spec, expand_composites=False) return self._element_spec @property def _type_spec(self): - return DistributedIteratorSpec(self._input_workers, self.element_spec, + # Note that we use actual element_spec to create DistributedIteratorSpec, + # to be consistent with the underlying iterators' specs. + # TODO(b/163362689): remove the comment after the bug if fixed. + return DistributedIteratorSpec(self._input_workers, self._element_spec, self._strategy, self._enable_get_next_as_optional) @@ -989,7 +1030,7 @@ class DistributedDataset(_IterableInput): self._input_workers = input_workers self._strategy = strategy self._enable_get_next_as_optional = _enable_get_next_as_optional( - self._strategy, dataset.element_spec) + self._strategy, dataset) self._element_spec = _create_distributed_tensor_spec( self._strategy, self._cloned_datasets[0].element_spec) @@ -1073,7 +1114,7 @@ class DistributedDataset(_IterableInput): worker_iterators, self._strategy, enable_get_next_as_optional=self._enable_get_next_as_optional) - iterator._element_spec = self.element_spec # pylint: disable=protected-access + iterator._element_spec = self._element_spec # pylint: disable=protected-access # When async eager is enabled, sometimes the iterator may not finish # initialization before passing to a multi device function, add a sync point @@ -1086,6 +1127,17 @@ class DistributedDataset(_IterableInput): @property def element_spec(self): """The type specification of an element of this dataset.""" + # When partial batch handling is enabled, always set the batch dimension to + # None, otherwise we just follow element_spec of the underlying dataset + # (whose batch dimension may also be None). This is because with partial + # batching handling we could always produce empty batches. + # + # TODO(b/163362689): avoid this once we have more elegent way to handle + # retracing and collectives. + if (self._enable_get_next_as_optional and + self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access + return nest.map_structure( + _rebatch_as_dynamic, self._element_spec, expand_composites=False) return self._element_spec @@ -1178,7 +1230,8 @@ class DistributedDatasetV1(DistributedDataset): class DistributedDatasetsFromFunction(_IterableInput): """Inputs created from dataset function.""" - def __init__(self, dataset_fn, input_workers, input_contexts, strategy): + def __init__(self, dataset_fn, input_workers, input_contexts, strategy, + options): """Makes an iterable from datasets created by the given function. Args: @@ -1189,6 +1242,8 @@ class DistributedDatasetsFromFunction(_IterableInput): `worker_device_pairs`. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. + options: `tf.distribute.InputOptions` used to control options on how this + dataset is distributed. """ super(DistributedDatasetsFromFunction, self).__init__( input_workers=input_workers) @@ -1202,12 +1257,12 @@ class DistributedDatasetsFromFunction(_IterableInput): self._input_workers = input_workers self._input_contexts = input_contexts self._strategy = strategy + self._options = options self._datasets, element_spec = ( - _create_datasets_per_worker_with_input_context(self._input_contexts, - self._input_workers, - dataset_fn)) + _create_datasets_from_function_with_input_context( + self._input_contexts, self._input_workers, dataset_fn)) self._enable_get_next_as_optional = _enable_get_next_as_optional( - self._strategy, element_spec) + self._strategy, self._datasets[0]) self._element_spec = _create_distributed_tensor_spec( self._strategy, element_spec) @@ -1220,11 +1275,10 @@ class DistributedDatasetsFromFunction(_IterableInput): # out this change. enable_legacy_iterators = getattr(self._strategy, "_enable_legacy_iterators", False) - iterators = _create_iterators_per_worker(self._datasets, self._input_workers, - enable_legacy_iterators) - + enable_legacy_iterators, + self._options) if enable_legacy_iterators: iterator = DistributedIteratorV1( self._input_workers, @@ -1233,9 +1287,9 @@ class DistributedDatasetsFromFunction(_IterableInput): enable_get_next_as_optional=self._enable_get_next_as_optional) else: iterator = DistributedIterator( - self._input_workers, - iterators, - self._strategy, + input_workers=self._input_workers, + iterators=iterators, + strategy=self._strategy, enable_get_next_as_optional=self._enable_get_next_as_optional) iterator._element_spec = self._element_spec # pylint: disable=protected-access @@ -1253,6 +1307,17 @@ class DistributedDatasetsFromFunction(_IterableInput): @property def element_spec(self): """The type specification of an element of this dataset.""" + # When partial batch handling is enabled, always set the batch dimension to + # None, otherwise we just follow element_spec of the underlying dataset + # (whose batch dimension may also be None). This is because with partial + # batching handling we could always produce empty batches. + # + # TODO(b/163362689): avoid this once we have more elegent way to handle + # retracing and collectives. + if (self._enable_get_next_as_optional and + self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access + return nest.map_structure( + _rebatch_as_dynamic, self._element_spec, expand_composites=False) return self._element_spec @@ -1350,6 +1415,7 @@ class InputFunctionIterator(DistributedIteratorV1): super(InputFunctionIterator, self).__init__( input_workers, iterators, strategy, enable_get_next_as_optional=False) + self._enable_get_next_as_optional = False # TODO(anjalisridhar): This class will soon be removed and users should move @@ -1475,7 +1541,7 @@ def _recover_shape_fn(data, value_structure): class _SingleWorkerDatasetIteratorBase(object): """Iterator for a single `tf.data.Dataset`.""" - def __init__(self, dataset, worker, devices): + def __init__(self, dataset, worker, devices, options=None): """Create iterator for the `dataset` to fetch data to worker's `devices` . A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch @@ -1485,21 +1551,46 @@ class _SingleWorkerDatasetIteratorBase(object): dataset: A `tf.data.Dataset` instance. worker: Worker on which ops should be created. devices: Distribute data from `dataset` to these devices. + options: options. """ self._dataset = dataset self._worker = worker self._devices = devices self._element_spec = dataset.element_spec + self._options = options self._make_iterator() def _make_iterator(self): raise NotImplementedError("must be implemented in descendants") + def _format_data_list_with_options(self, data_list): + """Change the data in to a list type if required. + + The OwnedMultiDeviceIterator returns the list data type, + while the PER_REPLICA iterator (when used with prefetch disabled) + returns without the enclosed list. This is to fix the inconsistency. + Args: + data_list: data_list + Returns: + list + """ + if (self._options and self._options.experimental_replication_mode == + InputReplicationMode.PER_REPLICA and + not self._options.experimental_prefetch_to_device): + return [data_list] + else: + return data_list + def get_next(self, device, name=None): """Get next element for the given device.""" del name with ops.device(self._worker): - return self._iterator.get_next(device) + if isinstance(self._iterator, + (multi_device_iterator_ops.OwnedMultiDeviceIterator, + multi_device_iterator_ops.MultiDeviceIterator)): + return self._iterator.get_next(device) + else: + return self._iterator.get_next() def get_next_as_list_static_shapes(self, name=None): """Get next element from the underlying iterator. @@ -1516,7 +1607,7 @@ class _SingleWorkerDatasetIteratorBase(object): """ del name with ops.device(self._worker): - return self._iterator.get_next() + return self._format_data_list_with_options(self._iterator.get_next()) def get_next_as_list(self, name=None): """Get next element from underlying iterator. @@ -1536,7 +1627,8 @@ class _SingleWorkerDatasetIteratorBase(object): """ del name with ops.device(self._worker): - data_list = self._iterator.get_next_as_optional() + data_list = self._format_data_list_with_options( + self._iterator.get_next_as_optional()) result = [] for i, data in enumerate(data_list): # Place the condition op in the same device as the data so the data @@ -1616,8 +1708,13 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, composite_tensor.CompositeTensor): """Iterator for a DistributedDataset instance.""" - def __init__(self, dataset=None, worker=None, devices=None, components=None, - element_spec=None): + def __init__(self, + dataset=None, + worker=None, + devices=None, + components=None, + element_spec=None, + options=None): """Create iterator for the `dataset` to fetch data to worker's `devices` . `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the @@ -1633,6 +1730,8 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, _SingleWorkerOwnedDatasetIterator from. element_spec: A nested structure of `TypeSpec` objects that represents the type specification of elements of the iterator. + options: `tf.distribute.InputOptions` used to control options on how this + dataset is distributed. """ if worker is None or devices is None: raise ValueError("Both `worker` and `devices` should be provided") @@ -1640,6 +1739,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, error_message = ("Either `dataset` or both `components` and `element_spec` " "need to be provided.") + self._options = options if dataset is None: if (components is None or element_spec is None): raise ValueError(error_message) @@ -1650,18 +1750,25 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, else: if (components is not None or element_spec is not None): raise ValueError(error_message) - super(_SingleWorkerOwnedDatasetIterator, self).__init__(dataset, worker, - devices) + super(_SingleWorkerOwnedDatasetIterator, + self).__init__(dataset, worker, devices, options) def _make_iterator(self): """Make appropriate iterator on the dataset.""" if not self._worker: raise ValueError("Worked device must be specified when creating an " "owned iterator.") - host_device = device_util.get_host_for_device(self._worker) - with ops.device(self._worker): - self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator( - self._dataset, self._devices, source_device=host_device) + if (self._options is None or self._options.experimental_replication_mode == + InputReplicationMode.PER_WORKER or + (self._options.experimental_replication_mode == InputReplicationMode + .PER_REPLICA and self._options.experimental_prefetch_to_device)): + host_device = device_util.get_host_for_device(self._worker) + with ops.device(self._worker): + self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator( + self._dataset, self._devices, source_device=host_device) + else: + with ops.device(self._worker): + self._iterator = iter(self._dataset) @property def element_spec(self): @@ -1782,19 +1889,23 @@ class _SingleWorkerCallableIterator(object): return [] -def _create_iterators_per_worker(worker_datasets, input_workers, - enable_legacy_iterators): +def _create_iterators_per_worker(worker_datasets, + input_workers, + enable_legacy_iterators, + options=None): """Create a multidevice iterator on each of the workers.""" assert isinstance(input_workers, InputWorkers) - assert len(worker_datasets) == len(input_workers.worker_devices) iterators = [] for i, worker in enumerate(input_workers.worker_devices): with ops.device(worker): worker_devices = input_workers.compute_devices_for_worker(i) if tf2.enabled() and not enable_legacy_iterators: - iterator = _SingleWorkerOwnedDatasetIterator(worker_datasets[i], worker, - worker_devices) + iterator = _SingleWorkerOwnedDatasetIterator( + dataset=worker_datasets[i], + worker=worker, + devices=worker_devices, + options=options) else: iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker, worker_devices) @@ -1802,8 +1913,9 @@ def _create_iterators_per_worker(worker_datasets, input_workers, return iterators -def _create_datasets_per_worker_with_input_context(input_contexts, - input_workers, dataset_fn): +def _create_datasets_from_function_with_input_context(input_contexts, + input_workers, + dataset_fn): """Create device datasets per worker given a dataset function.""" datasets = [] for i, ctx in enumerate(input_contexts): @@ -1993,13 +2105,14 @@ def _create_distributed_tensor_spec(strategy, tensor_spec): """ num_replicas = len(strategy.extended.worker_devices) - # If the number of devices used in the strategy is just 1 then we return - # the tensor_spec as is. - if num_replicas == 1: + # For one device strategy that is not MultiWorkerMirroredStrategy, return the + # tensor_spec as is, since we don't wrap the output with PerReplica in this + # case. + # TODO(b/166464552): remove after we always wrap for all strategies. + if not _always_wrap(strategy): return tensor_spec - # If the number of devices is greater than 1 then we assume the input to - # tf.function is a per replica type. + # For other cases we assume the input to tf.function is a per replica type. def _get_value_per_replica(tensor_spec_per_input): value_specs = [tensor_spec_per_input for _ in range(num_replicas)] return values.PerReplicaSpec(*value_specs) @@ -2015,7 +2128,7 @@ def _replace_per_replica_spec(spec, i): return spec -def _enable_get_next_as_optional(strategy, element_spec): +def _enable_get_next_as_optional(strategy, dataset): """Returns whether to enable using partial batch handling.""" # TODO(b/133073708): we currently need a flag to control the usage because # there is a performance difference between get_next() and @@ -2027,5 +2140,81 @@ def _enable_get_next_as_optional(strategy, element_spec): if not getattr(strategy.extended, "experimental_enable_get_next_as_optional", False): return False + + if context.executing_eagerly() and cardinality.cardinality( + dataset).numpy() == cardinality.INFINITE: + # If the dataset is inifinite, we don't need to enable last partial batch + # support. Currently the logic only applies to the case that distributed + # dataset is created in eager mode, as we need to evaluate the dataset + # cardinality. + return False + return not _is_statically_shaped( - element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access + dataset.element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access + + +def _create_per_replica(value_list, strategy, get_next_as_optional): + """Creates a PerReplica. + + For strategies other than OneDeviceStrategy, it creates a PerReplica whose + type spec is set to the element spec of the dataset. This helps avoid + retracing for partial batches. Retracing is problematic for multi client when + different client retraces different time, since retracing changes the + collective keys in the tf.function, and causes mismatches among clients. + + For single client strategies, this simply calls distribute_utils.regroup(). + + Args: + value_list: a list of values, one for each replica. + strategy: the `tf.distribute.Strategy`. + get_next_as_optional: whether last partial batch handling is enabled. + + Returns: + a structure of PerReplica. + + """ + # TODO(b/166464552): always wrap for all one device strategies as well. + always_wrap = _always_wrap(strategy) + per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap) + + # When partial batch handling is enabled, always set the batch dimension to + # None, otherwise we just follow element_spec of the underlying dataset + # (whose batch dimension may also be None). This is because with partial + # batching handling we could always produce empty batches. + # + # TODO(b/163362689): avoid this once we have more elegent way to handle + # retracing and collectives. + if (get_next_as_optional and strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access + # Use expand_composites=False since we don't want to expand PerReplica, + # which is a CompositeTensor. + flat_per_replicas = nest.flatten(per_replicas, expand_composites=False) + flat_spec = [type_spec.type_spec_from_value(v) for v in flat_per_replicas] + for per_replica, spec in zip(flat_per_replicas, flat_spec): + per_replica._type_spec_override = _rebatch_as_dynamic(spec) # pylint: disable=protected-access + per_replicas = nest.pack_sequence_as(per_replicas, flat_per_replicas) + + return per_replicas + + +def _always_wrap(strategy): + """Returns whether to always wrap the values in a DistributedValues.""" + return strategy.extended._in_multi_worker_mode() or len( # pylint: disable=protected-access + strategy.extended.worker_devices) > 1 + + +def _rebatch_as_dynamic(per_replica_spec): + """Rebatch the spec to have a dynamic batch dimension.""" + assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec + + # pylint: disable=protected-access + def _rebatch(spec): + # Rebatch if possible. + try: + return spec._unbatch()._batch(None) + except ValueError: + pass + return spec + + return values.PerReplicaSpec( + *nest.map_structure(_rebatch, per_replica_spec._value_specs)) + # pylint: enable=protected-access diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 5abd6f483d3..3d9bce327c5 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -557,7 +557,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, iterator = iter(dist_dataset) for i, element in enumerate(iterator): - self.assertEqual(i, element.numpy()) + self.assertAllEqual(distribution.experimental_local_results(element), [i]) @combinations.generate( combinations.combine( @@ -1421,5 +1421,198 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, input_context=distribution.extended._make_input_context()) +class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, + parameterized.TestCase): + """Tests for PER_WORKER and PER_REPLICA's InputOptions variants.""" + + def setUp(self): + context._reset_context() + strategy_combinations.set_virtual_cpus_to_at_least(3) + super(DistributedIteratorPerDeviceTest, self).setUp() + + @combinations.generate( + combinations.combine( + input_options=[ + distribute_lib.InputOptions( + experimental_place_dataset_on_device=False, + experimental_prefetch_to_device=True, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_WORKER), + distribute_lib.InputOptions( + experimental_place_dataset_on_device=False, + experimental_prefetch_to_device=True, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_REPLICA), + ], + mode=["eager"], + distribution=[ + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ])) + def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution, + input_options): + + def dataset_fn(input_context): # pylint: disable=[unused-argument] + return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4]) + + ds = distribution.experimental_distribute_datasets_from_function( + dataset_fn, input_options) + + for x in ds: + assert x.values[0].device == distribution.extended.worker_devices[0] + assert x.values[0].backing_device == distribution.extended.worker_devices[ + 0] + assert x.values[1].device == distribution.extended.worker_devices[1] + assert x.values[1].backing_device == distribution.extended.worker_devices[ + 1] + + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + input_options=[ + distribute_lib.InputOptions( + experimental_place_dataset_on_device=False, + experimental_prefetch_to_device=False, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_WORKER) + ], + mode=["eager"], + )) + def testDevicePlacementForPerWorkerValuesWithoutPrefetch( + self, distribution, input_options): + + def dataset_fn(input_context): + return dataset_ops.Dataset.from_tensor_slices( + np.full(4, input_context.input_pipeline_id)) + + ds = distribution.experimental_distribute_datasets_from_function( + dataset_fn, input_options) + + for x in ds: + x = distribution.run(lambda inputs: inputs, args=(x,)) + assert x.values[ + 0].device == "/job:localhost/replica:0/task:0/device:CPU:0" + assert x.values[ + 0].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0" + assert x.values[ + 1].device == "/job:localhost/replica:0/task:0/device:CPU:0" + assert x.values[ + 1].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0" + + @combinations.generate( + combinations.combine( + input_options=[ + distribute_lib.InputOptions( + experimental_place_dataset_on_device=True, + experimental_prefetch_to_device=False, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_WORKER), + distribute_lib.InputOptions( + experimental_place_dataset_on_device=True, + experimental_prefetch_to_device=True, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_REPLICA) + ], + mode=["eager"], + distribution=[ + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ])) + def testDevicePlacementForInvalidCombinations(self, distribution, + input_options): + + def dataset_fn(input_context): + return dataset_ops.Dataset.from_tensor_slices( + np.full(4, input_context.input_pipeline_id)) + + with self.assertRaises(ValueError): + distribution.experimental_distribute_datasets_from_function( + dataset_fn, input_options) + + @combinations.generate( + combinations.combine( + input_options=[ + distribute_lib.InputOptions( + experimental_place_dataset_on_device=False, + experimental_prefetch_to_device=False, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_WORKER), + distribute_lib.InputOptions( + experimental_place_dataset_on_device=False, + experimental_prefetch_to_device=True, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_WORKER), + ], + mode=["eager"], + distribution=[ + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ])) + def testOutputValuesForPerWorkerInputOptions(self, distribution, + input_options): + + def dataset_fn(input_context): + return dataset_ops.Dataset.from_tensor_slices( + np.arange(1, 11).reshape( + (2, 5)) * (input_context.input_pipeline_id + 1)) + + ds = distribution.experimental_distribute_datasets_from_function( + dataset_fn, input_options) + + # validating the values + x = next(iter(ds)) + assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5])) + assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10])) + + @combinations.generate( + combinations.combine( + input_options=[ + distribute_lib.InputOptions( + experimental_place_dataset_on_device=True, + experimental_prefetch_to_device=False, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_REPLICA), + distribute_lib.InputOptions( + experimental_place_dataset_on_device=False, + experimental_prefetch_to_device=False, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_REPLICA), + distribute_lib.InputOptions( + experimental_place_dataset_on_device=False, + experimental_prefetch_to_device=True, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_REPLICA), + ], + mode=["eager"], + distribution=[ + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ])) + def testOutputValuesForPerReplicaInputOptions(self, distribution, + input_options): + + def dataset_fn(input_context): + return dataset_ops.Dataset.from_tensor_slices( + np.arange(1, 10) * (input_context.input_pipeline_id + 1)) + + ds = distribution.experimental_distribute_datasets_from_function( + dataset_fn, input_options) + expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + for i, x in enumerate(ds): + # validating the values + assert x.values[0].numpy() == expected[i] + assert x.values[1].numpy() == expected[i] * 2 + loop_num = i + assert loop_num == len(expected) - 1 + + if __name__ == "__main__": test_util.main() diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index bc6ac811bbb..940949efd87 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -18,15 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from absl.testing import parameterized import numpy as np from tensorflow.python import tf2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import values from tensorflow.python.eager import def_function @@ -37,6 +40,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops +from tensorflow.python.ops import string_ops from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib from tensorflow.python.util import nest @@ -116,14 +120,17 @@ class DistributedIteratorTest(test.TestCase, distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.tpu_strategy, + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, ], - enable_get_next_as_optional=[True, False])) + enable_get_next_as_optional=[True, False], + drop_remainder=[True, False], + tf_api_version=2, + )) def testDoesNotTriggerFunctionTracing(self, input_type, distribution, - enable_get_next_as_optional): - if not tf2.enabled(): - self.skipTest("DistributedIterator CompositeTensor support is only " - "present in TF 2.0 only.") - + enable_get_next_as_optional, + drop_remainder): trace_count = [0] @def_function.function @@ -135,7 +142,8 @@ class DistributedIteratorTest(test.TestCase, counter += 1 return counter - dataset = dataset_ops.DatasetV2.range(10).batch(2) + dataset = dataset_ops.DatasetV2.range(10).batch( + 2, drop_remainder=drop_remainder) distribution.extended.experimental_enable_get_next_as_optional = ( enable_get_next_as_optional) @@ -161,27 +169,79 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): strategy_combinations.mirrored_strategy_with_one_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, + strategy_combinations.central_storage_strategy_with_gpu_and_cpu, + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, ], - input_type=["dataset", "dataset_fn"], + tf_api_version=2, + enable_get_next_as_optional=[True, False], + drop_remainder=[True, False], )) - def testInputSignatureForPerReplicaValues(self, distribution, input_type): - def dataset_fn(ctx): - del ctx # unused - return dataset_ops.DatasetV2.from_tensor_slices( - np.ones([10, 12]).astype(np.float32)).batch(4) + def testInputSignatureForPerReplicaValues(self, distribution, + enable_get_next_as_optional, + drop_remainder): + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + ds = dataset_ops.DatasetV2.from_tensor_slices( + np.ones([9, 12]).astype(np.float32)).batch( + 4, drop_remainder=drop_remainder) + ds = distribution.experimental_distribute_dataset(ds) + _check_type_spec_structure(iter(ds)) + element_spec = ds.element_spec + iter_element_spec = iter(ds).element_spec + nest.assert_same_structure(element_spec, iter_element_spec) + self.assertAllEqual( + nest.flatten(element_spec), nest.flatten(iter_element_spec)) - if input_type == "dataset": - ds = distribution.experimental_distribute_dataset( - dataset_fn(distribute_lib.InputContext())) - type_spec = ds.element_spec - else: - ds = distribution.distribute_datasets_from_function(dataset_fn) - iterator = iter(ds) - _check_type_spec_structure(iterator) - type_spec = iterator.element_spec + @def_function.function(input_signature=[element_spec]) + def process_inputs(inputs): + distribution.run(lambda inputs: inputs, args=(inputs,)) - @def_function.function(input_signature=[type_spec]) + for x in ds: + process_inputs(x) + + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + ], + tf_api_version=2, + enable_get_next_as_optional=[True, False], + drop_remainder=[True, False], + )) + def testFromFunctionInputSignatureForPerReplicaValues( + self, distribution, enable_get_next_as_optional, drop_remainder): + # Create files that produce partial/empty batches at different batch. Note + # that some worker will get empty batches even when drop_remainder=True. + fname1 = os.path.join(self.get_temp_dir(), "1.txt") + _create_text_file(fname1, 5) + fname2 = os.path.join(self.get_temp_dir(), "2.txt") + _create_text_file(fname2, 9) + + def dataset_fn(input_context): + dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2]) + dataset = dataset.shard(input_context.num_input_pipelines, + input_context.input_pipeline_id) + return readers.TextLineDatasetV2(dataset).map( + string_ops.string_to_number).batch( + input_context.get_per_replica_batch_size(4), + drop_remainder=drop_remainder) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + ds = distribution.experimental_distribute_datasets_from_function(dataset_fn) + _check_type_spec_structure(iter(ds)) + element_spec = ds.element_spec + iter_element_spec = iter(ds).element_spec + nest.assert_same_structure(element_spec, iter_element_spec) + self.assertAllEqual( + nest.flatten(element_spec), nest.flatten(iter_element_spec)) + + @def_function.function(input_signature=[element_spec]) def process_inputs(inputs): distribution.run(lambda inputs: inputs, args=(inputs,)) @@ -247,6 +307,149 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): self.assertEqual(spec1, spec1.most_specific_compatible_type(spec2)) self.assertEqual(spec1, spec2.most_specific_compatible_type(spec1)) + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + ], + tf_api_version=2, + drop_remainder=[True, False], + )) + def testFromDatasetDoesNotTriggerFunctionTracing(self, distribution, + drop_remainder): + self.trace_count = 0 + + @def_function.function + def f(v): + del v + self.trace_count += 1 + + distribution.extended.experimental_enable_get_next_as_optional = True + # Total dataset size 5 allows us to have full batches, partial batches and + # empty batches. + dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3))).batch( + 4, drop_remainder=drop_remainder) + dataset = distribution.experimental_distribute_dataset(dataset) + for v in iter(dataset): + f(v) + self.assertEqual(self.trace_count, 1) + + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + ], + tf_api_version=2, + drop_remainder=[True, False], + )) + def testFromDatasetFileShardingDoesNotTriggerFunctionTracing( + self, distribution, drop_remainder): + # Create files that produce partial/empty batches at different batch. + fname1 = os.path.join(self.get_temp_dir(), "1.txt") + _create_text_file(fname1, 5) + fname2 = os.path.join(self.get_temp_dir(), "2.txt") + _create_text_file(fname2, 9) + + self.trace_count = 0 + + @def_function.function + def f(v): + del v + self.trace_count += 1 + + distribution.extended.experimental_enable_get_next_as_optional = True + dataset = readers.TextLineDatasetV2([fname1, fname2]).batch( + 4, drop_remainder=drop_remainder) + dataset = distribution.experimental_distribute_dataset(dataset) + for v in iter(dataset): + f(v) + self.assertEqual(self.trace_count, 1) + + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + ], + tf_api_version=2, + drop_remainder=[True, False], + )) + def testFromFunctionDoesNotTriggerFunctionTracing(self, distribution, + drop_remainder): + + def dataset_fn(input_context): + # Total dataset size 5 allows us to have full batches, partial batches and + # empty batches. + dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3))) + dataset = dataset.batch( + input_context.get_per_replica_batch_size(4), + drop_remainder=drop_remainder) + return dataset.shard(input_context.num_input_pipelines, + input_context.input_pipeline_id) + + self.trace_count = 0 + + @def_function.function + def f(v): + del v + self.trace_count += 1 + + distribution.extended.experimental_enable_get_next_as_optional = True + dataset = distribution.experimental_distribute_datasets_from_function( + dataset_fn) + for v in iter(dataset): + f(v) + self.assertEqual(self.trace_count, 1) + + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + ], + tf_api_version=2, + drop_remainder=[True, False], + )) + def testFromFunctionFileShardingDoesNotTriggerFunctionTracing( + self, distribution, drop_remainder): + # Create files that produce partial/empty batches at different batch. + fname1 = os.path.join(self.get_temp_dir(), "1.txt") + _create_text_file(fname1, 5) + fname2 = os.path.join(self.get_temp_dir(), "2.txt") + _create_text_file(fname2, 9) + + def dataset_fn(input_context): + dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2]) + dataset = dataset.shard(input_context.num_input_pipelines, + input_context.input_pipeline_id) + return readers.TextLineDatasetV2(dataset).batch( + input_context.get_per_replica_batch_size(4), + drop_remainder=drop_remainder) + + self.trace_count = 0 + + @def_function.function + def f(v): + del v + self.trace_count += 1 + + distribution.extended.experimental_enable_get_next_as_optional = True + dataset = distribution.experimental_distribute_datasets_from_function( + dataset_fn) + for v in iter(dataset): + f(v) + self.assertEqual(self.trace_count, 1) + class RaggedTensorDistributedIteratorTest(test.TestCase, parameterized.TestCase): @@ -254,14 +457,14 @@ class RaggedTensorDistributedIteratorTest(test.TestCase, @combinations.generate( combinations.combine( mode=["eager"], + tf_api_version=2, distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.central_storage_strategy_with_gpu_and_cpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, ], enable_get_next_as_optional=[True, False])) def testTypeSpec(self, distribution, enable_get_next_as_optional): - if not tf2.enabled(): - self.skipTest("DistributedIterator has CompositeTensor support in " - "TF 2.0 only.") ctx = distribute_lib.InputContext() batch_size = ctx.get_per_replica_batch_size(8) # Use 20 which isn't divisible by 8 to test partial batch behavior. @@ -313,16 +516,16 @@ class RaggedTensorDistributedIteratorTest(test.TestCase, @combinations.generate( combinations.combine( mode=["eager"], + tf_api_version=2, distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.tpu_strategy, + strategy_combinations.central_storage_strategy_with_gpu_and_cpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, ], + enable_get_next_as_optional=[True, False])) def testTypeSpecRoundTrip(self, distribution, enable_get_next_as_optional): - if not tf2.enabled(): - self.skipTest("DistributedIterator CompositeTensor support is only " - "present in TF 2.0 only.") - ctx = distribute_lib.InputContext() batch_size = ctx.get_per_replica_batch_size(8) # Use 20 which isn't divisible by 8 to test partial batch behavior. @@ -366,17 +569,17 @@ class RaggedTensorDistributedIteratorTest(test.TestCase, @combinations.generate( combinations.combine( mode=["eager"], + tf_api_version=2, distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.tpu_strategy, + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, ], enable_get_next_as_optional=[True, False])) def testDoesNotTriggerFunctionTracing(self, distribution, enable_get_next_as_optional): - if not tf2.enabled(): - self.skipTest("DistributedIterator CompositeTensor support is only " - "present in TF 2.0 only.") - trace_count = [0] @def_function.function @@ -432,5 +635,11 @@ def _check_type_spec_structure(x): nest.assert_same_structure(x, x._type_spec, expand_composites=True) +def _create_text_file(fname, num_lines): + with open(fname, "w") as f: + for i in range(num_lines): + f.write("%d\n" % i) + + if __name__ == "__main__": - test.main() + test_util.main() diff --git a/tensorflow/python/distribute/integration_test/BUILD b/tensorflow/python/distribute/integration_test/BUILD index d3a71e0136d..c9947be442b 100644 --- a/tensorflow/python/distribute/integration_test/BUILD +++ b/tensorflow/python/distribute/integration_test/BUILD @@ -9,6 +9,9 @@ package( distribute_py_test( name = "saved_model_test", srcs = ["saved_model_test.py"], + tags = [ + "no_windows", # TODO(b/171350360) + ], deps = [ "//tensorflow:tensorflow_py", "//tensorflow/python:lookup_ops", @@ -33,6 +36,7 @@ cuda_py_test( shard_count = 2, tags = [ "multi_and_single_gpu", + "no_oss_py38", #TODO(b/171435331) ], deps = [ "//tensorflow:tensorflow_py", diff --git a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py index 02dee6f6adb..283a76300a9 100644 --- a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py +++ b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py @@ -37,6 +37,8 @@ from tensorflow.python.eager import test mwms_lib.CollectiveAllReduceExtended._enable_check_health = True mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3 mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 0 +# This is needed for OSS, which issues all RPCs with fail_fast=false by default. +mwms_lib.CollectiveAllReduceExtended._check_health_timeout = 1 def get_attempt(strategy, attempts): diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 658b45ecec6..6db1aeb6ca3 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -20,6 +20,7 @@ from __future__ import print_function import copy +from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib @@ -313,6 +314,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): assert devices, ("Got an empty `devices` list and unable to recognize " "any local devices.") self._cross_device_ops = cross_device_ops + self._communication_options = collective_util.Options() self._initialize_strategy(devices) # TODO(b/128995245): Enable last partial batch support in graph mode. @@ -339,6 +341,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): self._devices = tuple(device_util.canonicalize(d) for d in devices) self._input_workers_devices = ( (device_util.canonicalize("/device:CPU:0", devices[0]), devices),) + self._inferred_cross_device_ops = None if self._cross_device_ops else ( cross_device_ops_lib.select_cross_device_ops(devices)) self._host_input_device = numpy_dataset.SingleDevice( @@ -394,12 +397,27 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): logging.info("Using MirroredStrategy with remote devices %r", devices) def _input_workers_with_options(self, options=None): - if not options or options.experimental_prefetch_to_device: + if not options: + return input_lib.InputWorkers(self._input_workers_devices) + if (options.experimental_replication_mode == + distribute_lib.InputReplicationMode.PER_REPLICA): + if options.experimental_place_dataset_on_device: + self._input_workers_devices = ( + tuple( + (device_util.canonicalize(d, d), (d,)) for d in self._devices)) + else: + self._input_workers_devices = ( + tuple((device_util.canonicalize("/device:CPU:0", d), (d,)) + for d in self._devices)) return input_lib.InputWorkers(self._input_workers_devices) else: - return input_lib.InputWorkers( - [(host_device, (host_device,) * len(compute_devices)) for - host_device, compute_devices in self._input_workers_devices]) + if not options.experimental_prefetch_to_device: + return input_lib.InputWorkers([ + (host_device, (host_device,) * len(compute_devices)) + for host_device, compute_devices in self._input_workers_devices + ]) + else: + return input_lib.InputWorkers(self._input_workers_devices) @property def _input_workers(self): @@ -497,6 +515,13 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): self._container_strategy()) def _experimental_distribute_dataset(self, dataset, options): + if (options and options.experimental_replication_mode == + distribute_lib.InputReplicationMode.PER_REPLICA): + raise NotImplementedError( + "InputReplicationMode.PER_REPLICA " + "is only supported in " + "`experimental_distribute_datasets_from_function`." + ) return input_lib.get_distributed_dataset( dataset, self._input_workers_with_options(options), @@ -508,8 +533,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): numpy_input, self._host_input_device, session) def _distribute_datasets_from_function(self, dataset_fn, options): - input_contexts = [] input_workers = self._input_workers_with_options(options) + input_contexts = [] num_workers = input_workers.num_workers for i in range(num_workers): input_contexts.append(distribute_lib.InputContext( @@ -518,10 +543,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): num_replicas_in_sync=self._num_replicas_in_sync)) return input_lib.get_distributed_datasets_from_function( - dataset_fn, - input_workers, - input_contexts, - self._container_strategy()) + dataset_fn, input_workers, input_contexts, self._container_strategy(), + options) def _experimental_distribute_values_from_function(self, value_fn): per_replica_values = [] @@ -632,8 +655,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): del value # Unused. return self._cross_device_ops or self._inferred_cross_device_ops - def _gather_to_implementation(self, value, destinations, axis, - experimental_hints): + def _gather_to_implementation(self, value, destinations, axis, options): if not isinstance(value, values.DistributedValues): # ReductionToOneDevice._gather accepts DistributedValues only. return value @@ -641,9 +663,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): value, destinations=destinations, axis=axis, - experimental_hints=experimental_hints) + options=self._communication_options.merge(options)) - def _reduce_to(self, reduce_op, value, destinations, experimental_hints): + def _reduce_to(self, reduce_op, value, destinations, options): if (distribute_utils.is_mirrored(value) and reduce_op == reduce_util.ReduceOp.MEAN): return value @@ -659,10 +681,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): reduce_op, value, destinations=destinations, - experimental_hints=experimental_hints) + options=self._communication_options.merge(options)) - def _batch_reduce_to(self, reduce_op, value_destination_pairs, - experimental_hints): + def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): cross_device_ops = None for value, _ in value_destination_pairs: if cross_device_ops is None: @@ -670,8 +691,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): elif cross_device_ops is not self._get_cross_device_ops(value): raise ValueError("inputs to batch_reduce_to must be either all on the " "the host or all on the compute devices") - return cross_device_ops.batch_reduce(reduce_op, value_destination_pairs, - experimental_hints) + return cross_device_ops.batch_reduce( + reduce_op, + value_destination_pairs, + options=self._communication_options.merge(options)) def _update(self, var, fn, args, kwargs, group): # TODO(josh11b): In eager mode, use one thread per device. diff --git a/tensorflow/python/distribute/multi_process_lib.py b/tensorflow/python/distribute/multi_process_lib.py index 89021448eb2..81ee53a285f 100644 --- a/tensorflow/python/distribute/multi_process_lib.py +++ b/tensorflow/python/distribute/multi_process_lib.py @@ -12,41 +12,161 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""OSS multi-process library to be implemented.""" +"""Library for multi-process testing.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function - -import multiprocessing as _multiprocessing +import multiprocessing import os +import platform +import sys import unittest +from absl import app from tensorflow.python.eager import test -try: - multiprocessing = _multiprocessing.get_context('forkserver') -except ValueError: - # forkserver is not available on Windows. - multiprocessing = _multiprocessing.get_context('spawn') +def is_oss(): + """Returns whether the test is run under OSS.""" + return len(sys.argv) >= 1 and 'bazel' in sys.argv[0] -class Process(object): - """A process simulating a worker for testing multi-worker training.""" +def _is_enabled(): + # Note that flags may not be parsed at this point and simply importing the + # flags module causes a variety of unusual errors. + tpu_args = [arg for arg in sys.argv if arg.startswith('--tpu')] + if is_oss() and tpu_args: + return False + if sys.version_info == (3, 8) and platform.system() == 'Linux': + return False # TODO(b/171242147) + return sys.platform != 'win32' + + +class _AbslProcess: + """A process that runs using absl.app.run.""" def __init__(self, *args, **kwargs): - del args, kwargs - raise unittest.SkipTest( - 'TODO(b/150264776): Implement OSS version of `multi_process_lib`') + super(_AbslProcess, self).__init__(*args, **kwargs) + # Monkey-patch that is carried over into the spawned process by pickle. + self._run_impl = getattr(self, 'run') + self.run = self._run_with_absl + + def _run_with_absl(self): + app.run(lambda _: self._run_impl()) + + +if _is_enabled(): + + class AbslForkServerProcess(_AbslProcess, + multiprocessing.context.ForkServerProcess): + """An absl-compatible Forkserver process. + + Note: Forkserver is not available in windows. + """ + + class AbslForkServerContext(multiprocessing.context.ForkServerContext): + _name = 'absl_forkserver' + Process = AbslForkServerProcess # pylint: disable=invalid-name + + multiprocessing = AbslForkServerContext() + Process = multiprocessing.Process + +else: + + class Process(object): + """A process that skips test (until windows is supported).""" + + def __init__(self, *args, **kwargs): + del args, kwargs + raise unittest.SkipTest( + 'TODO(b/150264776): Windows is not supported in MultiProcessRunner.') + + +_test_main_called = False + + +def _set_spawn_exe_path(): + """Set the path to the executable for spawned processes. + + This utility searches for the binary the parent process is using, and sets + the executable of multiprocessing's context accordingly. + + Raises: + RuntimeError: If the binary path cannot be determined. + """ + # TODO(b/150264776): This does not work with Windows. Find a solution. + if sys.argv[0].endswith('.py'): + # If all we have is a python module path, we'll need to make a guess for + # the actual executable path. Since the binary path may correspond to the + # parent's path of the python module, we are making guesses by reducing + # directories one at a time. E.g., + # tensorflow/python/some/path/my_test.py + # -> tensorflow/python/some/path/my_test + # -> tensorflow/python/some/my_test + # -> tensorflow/python/my_test + path_to_use = None + guess_path = sys.argv[0][:-3] + guess_path = guess_path.split(os.sep) + for path_reduction in range(-1, -len(guess_path), -1): + possible_path = os.sep.join(guess_path[:path_reduction] + + [guess_path[-1]]) + if os.access(possible_path, os.X_OK): + path_to_use = possible_path + break + # The binary can possibly have _gpu suffix. + possible_path += '_gpu' + if os.access(possible_path, os.X_OK): + path_to_use = possible_path + break + if path_to_use is None: + raise RuntimeError('Cannot determine binary path') + sys.argv[0] = path_to_use + # Note that this sets the executable for *all* contexts. + multiprocessing.get_context().set_executable(sys.argv[0]) + + +def _if_spawn_run_and_exit(): + """If spawned process, run requested spawn task and exit. Else a no-op.""" + + # `multiprocessing` module passes a script "from multiprocessing.x import y" + # to subprocess, followed by a main function call. We use this to tell if + # the process is spawned. Examples of x are "forkserver" or + # "semaphore_tracker". + is_spawned = ('-c' in sys.argv[1:] and + sys.argv[sys.argv.index('-c') + + 1].startswith('from multiprocessing.')) + + if not is_spawned: + return + cmd = sys.argv[sys.argv.index('-c') + 1] + # As a subprocess, we disregarding all other interpreter command line + # arguments. + sys.argv = sys.argv[0:1] + + # Run the specified command - this is expected to be one of: + # 1. Spawn the process for semaphore tracker. + # 2. Spawn the initial process for forkserver. + # 3. Spawn any process as requested by the "spawn" method. + exec(cmd) # pylint: disable=exec-used + sys.exit(0) # Semaphore tracker doesn't explicitly sys.exit. def test_main(): """Main function to be called within `__main__` of a test file.""" + global _test_main_called + _test_main_called = True + os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' + + if _is_enabled(): + _set_spawn_exe_path() + _if_spawn_run_and_exit() + + # Only runs test.main() if not spawned process. test.main() def initialized(): """Returns whether the module is initialized.""" - return True + return _test_main_called diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 107aa1a6a48..4c72969c995 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -699,6 +699,8 @@ class MultiProcessRunner(object): sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM) for (task_type, task_id), p in self._processes.items(): if p.exitcode is not None: + logging.info('%s-%d has already exited. Not terminating.', task_type, + task_id) continue try: os.kill(p.pid, sig) @@ -866,6 +868,11 @@ def _shutdown_all_pool_runners(): pool.shutdown() +def is_oss(): + """Returns whether the test is run under OSS.""" + return len(sys.argv) >= 1 and 'bazel' in sys.argv[0] + + class MultiProcessPoolRunner(object): """A utility class to start a process pool to simulate a cluster. @@ -919,6 +926,9 @@ class MultiProcessPoolRunner(object): if dill is None: raise unittest.SkipTest( 'TODO(b/150264776): Resolve dependency issue in CI') + if is_oss(): + raise unittest.SkipTest( + 'TODO(b/170360740): MultiProcessPoolRunner timing out in OSS') self._runner = MultiProcessRunner( fn=lambda: None, diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py index c164bd490e3..c5cb3e4e9f1 100644 --- a/tensorflow/python/distribute/multi_process_runner_test.py +++ b/tensorflow/python/distribute/multi_process_runner_test.py @@ -319,6 +319,9 @@ class MultiProcessRunnerTest(test.TestCase): def test_seg_fault_raises_error(self): + if multi_process_runner.is_oss(): + self.skipTest('TODO(b/171004637): Failing in OSS') + def fn_expected_to_seg_fault(): ctypes.string_at(0) # Intentionally made seg fault. @@ -331,10 +334,14 @@ class MultiProcessRunnerTest(test.TestCase): self.assertIn('Subprocess worker-0 exited with exit code', str(cm.exception)) list_to_assert = cm.exception.mpr_result.stdout - self.assertTrue(any('SIGSEGV' in line for line in list_to_assert)) + self.assertTrue( + any('Segmentation fault' in line for line in list_to_assert)) def test_seg_fault_in_chief_raises_error(self): + if multi_process_runner.is_oss(): + self.skipTest('TODO(b/171004637): Failing in OSS') + def fn_expected_to_seg_fault(): if multi_worker_test_base.get_task_type() == 'worker': time.sleep(10000) @@ -350,7 +357,8 @@ class MultiProcessRunnerTest(test.TestCase): self.assertIn('Subprocess chief-0 exited with exit code', str(cm.exception)) list_to_assert = cm.exception.mpr_result.stdout - self.assertTrue(any('SIGSEGV' in line for line in list_to_assert)) + self.assertTrue( + any('Segmentation fault' in line for line in list_to_assert)) def test_exit_code_is_reported_by_chief_subprocess(self): @@ -514,6 +522,9 @@ class MultiProcessRunnerTest(test.TestCase): def test_timeout_none(self): + if multi_process_runner.is_oss(): + self.skipTest('Intentionally skipping longer test in OSS.') + def fn(): time.sleep(250) raise ValueError('Worker 0 errored') @@ -579,9 +590,13 @@ class MultiProcessPoolRunnerTest(test.TestCase): self.assertAllEqual(result, [1, 1]) def test_global_pool(self): + if multi_process_runner.is_oss(): + self.skipTest('TODO(b/170360740): Failing in OSS') _global_pool.run(fn_that_does_nothing) def test_nested_pool(self): + if multi_process_runner.is_oss(): + self.skipTest('TODO(b/170360740): Failing in OSS') def fn(): # This runs in sub processes, so they are each using their own diff --git a/tensorflow/python/distribute/multi_worker_test_base.py b/tensorflow/python/distribute/multi_worker_test_base.py index 9e56e6d1bf7..5809182b2a8 100644 --- a/tensorflow/python/distribute/multi_worker_test_base.py +++ b/tensorflow/python/distribute/multi_worker_test_base.py @@ -234,6 +234,11 @@ class MultiProcessCluster(object): server_config = config_pb2.ConfigProto() server_config.device_count['GPU'] = 0 + # Set the environment variable to prevent hanging upon job failure and + # restart. Note that it defaults to 'use_caller' at Google, but defaults + # to False in OSS. + os.environ['GRPC_FAIL_FAST'] = 'use_caller' + server_lib.Server( cluster_spec, job_name=task_type, diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py index 3d5175d9055..946735352a3 100644 --- a/tensorflow/python/distribute/one_device_strategy.py +++ b/tensorflow/python/distribute/one_device_strategy.py @@ -312,12 +312,26 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1): def _experimental_distribute_dataset(self, dataset, options): # Note that split_batch_by argument is not passed because it is always 1 in # this strategy, and adding it adds unnecessary overhead to the dataset. + if (options and options.experimental_replication_mode == + distribute_lib.InputReplicationMode.PER_REPLICA): + raise NotImplementedError( + "InputReplicationMode.PER_REPLICA " + "is only supported in " + "`experimental_distribute_datasets_from_function`." + ) return input_lib.get_distributed_dataset( dataset, self._input_workers_with_options(options), self._container_strategy()) def _distribute_datasets_from_function(self, dataset_fn, options): + if (options and options.experimental_replication_mode == + distribute_lib.InputReplicationMode.PER_REPLICA): + raise NotImplementedError( + "InputReplicationMode.PER_REPLICA " + "is only supported in " + "`experimental_distribute_datasets_from_function` " + "of tf.distribute.MirroredStrategy") return input_lib.get_distributed_datasets_from_function( dataset_fn, self._input_workers_with_options(options), @@ -379,13 +393,12 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1): with ops.device(self._device), _OneDeviceReplicaContext(strategy): return fn(*args, **kwargs) - def _reduce_to(self, reduce_op, value, destinations, experimental_hints): - del reduce_op, destinations, experimental_hints + def _reduce_to(self, reduce_op, value, destinations, options): + del reduce_op, destinations, options return value - def _gather_to_implementation(self, value, destinations, axis, - experimental_hints): - del destinations, axis, experimental_hints + def _gather_to_implementation(self, value, destinations, axis, options): + del destinations, axis, options return value def _update(self, var, fn, args, kwargs, group): diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index 0cc2b21c3aa..312a3a483c7 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -47,9 +47,8 @@ from tensorflow.python.util.tf_export import tf_export _LOCAL_CPU = "/device:CPU:0" -# TODO(yuefengz): maybe cache variables on local CPU. -@tf_export("distribute.experimental.ParameterServerStrategy", v1=[]) -class ParameterServerStrategy(distribute_lib.Strategy): +@tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring +class ParameterServerStrategyV1(distribute_lib.StrategyV1): """An asynchronous multi-worker parameter server tf.distribute strategy. This strategy requires two roles: workers and parameter servers. Variables and @@ -112,52 +111,51 @@ class ParameterServerStrategy(distribute_lib.Strategy): """ if cluster_resolver is None: cluster_resolver = TFConfigClusterResolver() - if not cluster_resolver.cluster_spec(): - raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.") - extended = ParameterServerStrategyExtended( - self, cluster_resolver=cluster_resolver) - super(ParameterServerStrategy, self).__init__(extended) - - def experimental_distribute_dataset(self, dataset, options=None): - self._raise_pss_error_if_eager() - super(ParameterServerStrategy, - self).experimental_distribute_dataset(dataset=dataset, - options=options) - - def distribute_datasets_from_function(self, dataset_fn, options=None): - self._raise_pss_error_if_eager() - super(ParameterServerStrategy, self).distribute_datasets_from_function( - dataset_fn=dataset_fn, options=options) - - def run(self, fn, args=(), kwargs=None, options=None): - self._raise_pss_error_if_eager() - super(ParameterServerStrategy, self).run( - fn, args=args, kwargs=kwargs, options=options) - - def scope(self): - self._raise_pss_error_if_eager() - return super(ParameterServerStrategy, self).scope() - - def _raise_pss_error_if_eager(self): - if context.executing_eagerly(): - raise NotImplementedError("ParameterServerStrategy currently only works " - "with the tf.Estimator API") - - -@tf_export(v1=["distribute.experimental.ParameterServerStrategy"]) # pylint: disable=missing-docstring -class ParameterServerStrategyV1(distribute_lib.StrategyV1): - - __doc__ = ParameterServerStrategy.__doc__ - - def __init__(self, cluster_resolver=None): - """Initializes this strategy.""" super(ParameterServerStrategyV1, self).__init__( ParameterServerStrategyExtended( self, cluster_resolver=cluster_resolver)) distribute_lib.distribution_strategy_gauge.get_cell("V1").set( "ParameterServerStrategy") - __init__.__doc__ = ParameterServerStrategy.__init__.__doc__ + def experimental_distribute_dataset(self, dataset, options=None): + if (options and options.experimental_replication_mode == + distribute_lib.InputReplicationMode.PER_REPLICA): + raise NotImplementedError( + "InputReplicationMode.PER_REPLICA " + "is only supported in " + "`experimental_distribute_datasets_from_function`." + ) + self._raise_pss_error_if_eager() + super(ParameterServerStrategyV1, + self).experimental_distribute_dataset(dataset=dataset, + options=options) + + def distribute_datasets_from_function(self, dataset_fn, options=None): + if (options and options.experimental_replication_mode == + distribute_lib.InputReplicationMode.PER_REPLICA): + raise NotImplementedError( + "InputReplicationMode.PER_REPLICA " + "is only supported in " + "`experimental_distribute_datasets_from_function` " + "of tf.distribute.MirroredStrategy") + self._raise_pss_error_if_eager() + super(ParameterServerStrategyV1, self).distribute_datasets_from_function( + dataset_fn=dataset_fn, options=options) + + def run(self, fn, args=(), kwargs=None, options=None): + self._raise_pss_error_if_eager() + super(ParameterServerStrategyV1, self).run( + fn, args=args, kwargs=kwargs, options=options) + + def scope(self): + self._raise_pss_error_if_eager() + return super(ParameterServerStrategyV1, self).scope() + + def _raise_pss_error_if_eager(self): + if context.executing_eagerly(): + raise NotImplementedError( + "`tf.compat.v1.distribute.experimental.ParameterServerStrategy` " + "currently only works with the tf.Estimator API") # TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1. @@ -504,7 +502,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): (d, self._worker_device)) def _gather_to_implementation(self, value, destinations, axis, - experimental_hints): + options): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): return value @@ -512,27 +510,22 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): value, destinations=destinations, axis=axis, - experimental_hints=experimental_hints) + options=options) - def _reduce_to(self, reduce_op, value, destinations, experimental_hints): + def _reduce_to(self, reduce_op, value, destinations, options): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): # pylint: disable=protected-access return cross_device_ops_lib.reduce_non_distributed_value( reduce_op, value, destinations, self._num_replicas_in_sync) return self._cross_device_ops.reduce( - reduce_op, - value, - destinations=destinations, - experimental_hints=experimental_hints) + reduce_op, value, destinations=destinations, options=options) - def _batch_reduce_to(self, reduce_op, value_destination_pairs, - experimental_hints): + def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): for _, destinations in value_destination_pairs: self._verify_destinations_not_different_worker(destinations) return self._cross_device_ops.batch_reduce(reduce_op, - value_destination_pairs, - experimental_hints) + value_destination_pairs, options) def _select_single_value(self, structured): """Select any single value in `structured`.""" diff --git a/tensorflow/python/distribute/parameter_server_strategy_test.py b/tensorflow/python/distribute/parameter_server_strategy_test.py index 1b4cd21c249..c196fb4ad94 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_test.py +++ b/tensorflow/python/distribute/parameter_server_strategy_test.py @@ -84,7 +84,7 @@ def create_test_objects(cluster_spec=None, task_type=task_type, task_id=task_id, num_accelerators={'GPU': num_gpus}) - distribution = parameter_server_strategy.ParameterServerStrategy( + distribution = parameter_server_strategy.ParameterServerStrategyV1( cluster_resolver) target = 'grpc://' + cluster_spec[WORKER][task_id] else: @@ -748,7 +748,7 @@ class ParameterServerStrategyTest( task_type='worker', task_id=1, num_accelerators={'GPU': 0}) - strategy = parameter_server_strategy.ParameterServerStrategy( + strategy = parameter_server_strategy.ParameterServerStrategyV1( cluster_resolver) dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.]) diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py index d215be0dd94..4ed2e8fa43e 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py @@ -37,82 +37,421 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export -# pylint: disable=protected-access +@tf_export("distribute.experimental.ParameterServerStrategy", v1=[]) class ParameterServerStrategyV2(distribute_lib.Strategy): - """An asynchronous multi-worker parameter server tf.distribute strategy. + """An multi-worker tf.distribute strategy with parameter servers. - Currently, `ParameterServerStrategyV2` is not supported to be used as a - standalone tf.distribute strategy. It should be used in conjunction with - `Client`. Please see `Client` for more information. + Parameter server training is a common data-parallel method to scale up a + machine learning model on multiple machines. A parameter server training + cluster consists of workers and parameter servers. Variables are created on + parameter servers and they are read and updated by workers in each step. + By default, workers read and update these variables independently without + synchronizing with each other. Under this configuration, it is known as + asynchronous training. - This is currently under development, and the API as well as implementation - is subject to changes. + In TensorFlow 2, we recommend a central coordiantion-based architecture for + parameter server training, where workers and parameter servers run a + `tf.distribute.Server` and there is another task that creates resources on + workers and parameter servers, dispatches functions, and coordinates the + training. We refer to this task as “coordinator”. The coordinator uses a + `tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the + cluster, and a `tf.distribute.experimental.ParameterServerStrategy` to define + variables on parameter servers and computation on workers. + + For the training to work, the coordinator dispatches `tf.function`s to be + executed on remote workers. Upon receiving requests from + the coordinator, a worker executes the `tf.function` by reading the variables + from parameter servers, executing the ops, and updating the variables on the + parameter servers. Each of the worker only processes the requests from the + coordinator, and communicates with parameter servers, without direct + interactions with other workers in the cluster. + + As a result, failures of some workers do not prevent the cluster from + continuing the work, and this allows the cluster to train with instances that + can be occasionally unavailable (e.g. preemptible or spot instances). The + coordinator and parameter servers though, must be available at all times for + the cluster to make progress. + + Note that the coordinator is not one of the training workers. Instead, it + creates resources such as variables and datasets, dispatchs `tf.function`s, + saving checkpoints and so on. In addition to workers, parameter servers and + the coordinator, an optional evaluator can be run on the side that + periodically reads the checkpoints saved by the coordinator and runs + evaluations against each checkpoint. + + `tf.distribute.experimental.ParameterServerStrategy` has to work in + conjunction with a `tf.distribute.experimental.coordinator.ClusterCoordinator` + object. Standalone usage of + `tf.distribute.experimental.ParameterServerStrategy` without central + coordination is not supported at this time. + + __Example code for coordinator__ + + Here's an example usage of the API, with a custom training loop to train a + model. This code snippet is intended to be run on (the only) one task that + is designated as the coordinator. Note that `cluster_resolver`, + `variable_partitioner`, and `dataset_fn` arguments are explained in the + following "Cluster setup", "Variable partitioning", and "Dataset preparation" + sections. + + ```python + # Set the environment variable to allow reporting worker and ps failure to the + # coordinator. This a short-term workaround. + os.environ["GRPC_FAIL_FAST"] = "use_caller" + + # Prepare a strategy to use with the cluster and variable partitioning info. + strategy = tf.distribute.experimental.ParameterServerStrategy( + cluster_resolver=..., + variable_partitioner=...) + coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( + strategy=strategy) + + # Prepare a distribute dataset that will place datasets on the workers. + distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...) + + with strategy.scope(): + model = ... + optimizer, metrics = ... # Keras optimizer/metrics are great choices + checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) + checkpoint_manager = tf.train.CheckpointManager( + checkpoint, checkpoint_dir, max_to_keep=2) + # `load_checkpoint` infers initial epoch from `optimizer.iterations`. + initial_epoch = load_checkpoint(checkpoint_manager) or 0 + + @tf.function + def worker_fn(iterator): + + def replica_fn(inputs): + batch_data, labels = inputs + # calculate gradient, applying gradient, metrics update etc. + + strategy.run(replica_fn, args=(next(iterator),)) + + for epoch in range(initial_epoch, num_epoch): + distributed_iterator = iter(distributed_dataset) # Reset iterator state. + for step in range(steps_per_epoch): + + # Asynchronously schedule the `worker_fn` to be executed on an arbitrary + # worker. This call returns immediately. + coordinator.schedule(worker_fn, args=(distributed_iterator,)) + + # `join` blocks until all scheduled `worker_fn`s finish execution. Once it + # returns, we can read the metrics and save checkpoints as needed. + coordinator.join() + logging.info('Metric result: %r', metrics.result()) + train_accuracy.reset_states() + checkpoint_manager.save() + ``` + + __Example code for worker and parameter servers__ + + In addition to the coordinator, there should be tasks designated as + "worker" or "ps". They should run the following code to start a TensorFlow + server, waiting for coordinator's requests: + + ```python + # Set the environment variable to allow reporting worker and ps failure to the + # coordinator. + os.environ["GRPC_FAIL_FAST"] = "use_caller" + + # Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves + # the cluster information. See below "Cluster setup" section. + cluster_resolver = ... + + server = tf.distribute.Server( + cluster_resolver.cluster_spec(), + job_name=cluster_resolver.task_type, + task_index=cluster_resolver.task_id, + protocol="grpc") + + # Blocking the process that starts a server from exiting. + server.join() + ``` + + __Cluster setup__ + + In order for the tasks in the cluster to know other tasks' addresses, + a `tf.distribute.cluster_resolver.ClusterResolver` is required to be used + in coordinator, worker, and ps. The + `tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing + the cluster information, as well as the task type and id of the current task. + See `tf.distribute.cluster_resolver.ClusterResolver` for more information. + + If `TF_CONFIG` environment variable is set, a + `tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used as + well. Note that for legacy reason, on some platform, "chief" is used as the + task type for the coordinator, as the following example demonstrates. Here we + set `TF_CONFIG` for the task designated as a parameter server (task type "ps") + and index 1 (the second task), in a cluster with 1 chief, 2 parameter servers, + and 3 workers. Note that the it needs to be set before the use of + `tf.distribute.cluster_resolver.TFConfigClusterResolver`. + + Example code for cluster setup: + ```python + os.environ['TF_CONFIG'] = ''' + { + "cluster": { + "chief": ["chief.example.com:2222"], + "ps": ["ps0.example.com:2222", "ps1.example.com:2222"], + "worker": ["worker0.example.com:2222", "worker1.example.com:2222", + "worker2.example.com:2222"] + }, + "task": { + "type": "ps", + "index": 1 + } + } + ''' + ``` + + If you prefer to run the same binary for all tasks, you will need to let the + binary branch into different roles at the beginning of the program: + ```python + os.environ["GRPC_FAIL_FAST"] = "use_caller" + cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver() + + # If coordinator, create a strategy and start the training program. + if cluster_resolver.task_type == 'chief': + strategy = tf.distribute.experimental.ParameterServerStrategy( + cluster_resolver) + ... + + # If worker/ps, create a server + elif cluster_resolver.task_type in ("worker", "ps"): + server = tf.distribute.Server(...) + ... + ``` + Alternatively, you can also start a bunch of TensorFlow servers in advance and + connect to them later. The coordinator can be in the same cluster or on any + machine that has connectivity to workers and parameter server. This is covered + in our guide and tutorial. + + __Variable creation with `strategy.scope()`__ + + `tf.distribute.experimental.ParameterServerStrategy` follows the + `tf.distribute` API contract where variable creation is expected to be inside + the context manager returned by `strategy.scope()`, in order to be correctly + placed on parameter servers in a round-robin manner: + + ```python + # In this example, we're assuming having 3 ps. + strategy = tf.distribute.experimental.ParameterServerStrategy( + cluster_resolver=...) + coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( + strategy=strategy) + + # Variables should be created inside scope to be placed on parameter servers. + # If created outside scope such as `v1` here, it would be placed on the + # coordinator. + v1 = tf.Variable(initial_value=0.0) + + with strategy.scope(): + v2 = tf.Variable(initial_value=1.0) + v3 = tf.Variable(initial_value=2.0) + v4 = tf.Variable(initial_value=3.0) + v5 = tf.Variable(initial_value=4.0) + + # v2 through v5 are created in scope and are distributed on parameter servers. + # Default placement is round-robin but the order should not be relied on. + assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0" + assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0" + assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0" + assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0" + ``` + + See `distribute.Strategy.scope` for more information. + + __Variable partitioning__ + + Having dedicated servers to store variables means being able to divide up, or + "shard" the variables across the ps. Partitioning large variable among ps is a + commonly used technique to boost training throughput and mitigate memory + constraints. It enables parallel computations and updates on different shards + of a variable, and often yields better load balancing across parameter servers + . Without sharding, models with large variables (e.g, embeddings) that can't + fit into one machine's memory would otherwise be unable to train. + + With `tf.distribute.experimental.ParameterServerStrategy`, if a + `variable_partitioner` is provided to `__init__` and certain conditions are + satisfied, the resulting variables created in scope are sharded across the + parameter servers, in a round-robin fashion. The variable reference returned + from `tf.Variable` becomes a type that serves as the container of the sharded + variables. One can access `variables` attribute of this container for the + actual variable components. If building model with `tf.Module` or Keras, + the variable components are collected in the `variables` alike attributes. + + + ```python + class Dense(tf.Module): + def __init__(self, name=None): + super().__init__(name=name) + self.w = tf.Variable(tf.random.normal([100, 10]), name='w') + + def __call__(self, x): + return x * self.w + + # Partition the dense layer into 2 shards. + variable_partitioiner = ( + tf.distribute.experimental.partitioners.FixedShardsPartitioner( + num_shards = 2)) + strategy = ParameterServerStrategy(cluster_resolver=..., + variable_partitioner = variable_partitioner) + with strategy.scope(): + dense = Dense() + assert len(dense.variables) == 2 + assert isinstance(dense.variables[0], tf.Variable) + assert isinstance(dense.variables[1], tf.Variable) + assert dense.variables[0].name == "w/part_0" + assert dense.variables[1].name == "w/part_1" + ``` + + The sharded variable container can be converted to a `Tensor` via + `tf.convert_to_tensor`. This means the container can be directly used in most + Python Ops where such `Tensor` convertion automatically happens. For example + in the above code snippet, `x * self.w` would implicitly apply the said tensor + convertion. Note that such convertion can be expensive, as the variable + components need to be transferred from multiple parameter servers to where + the value is used. + + `tf.nn.embedding_lookup` on the other hand doesn't apply the tensor convertion + , and performs parallel lookups on the variable components instead. This is + crutial to scale up embedding lookups when the embedding table variable is + large. + + When a partitioned variable is saved to `SavedModel`, it will be saved as if + it is one single variable. This improves serving efficiency by eliminating + a number of Ops that handle the partiton aspects. + + Known limitations of variable partitioning: + + * Number of parttions must not change across Checkpoint save/load. + + * After saving partitioned variables to a SavedModel, the SavedModel can't be + loaded via `tf.saved_model.load`. + + * Partition variable doesn't directly work with `tf.GradientTape`, please use + the `variables` attributes to get the actual variable components and use + them in gradient APIs instead. + + __Dataset preparation__ + + With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is + created in each of the workers to be used for training. This is done by + creating a `dataset_fn` that takes no argument and returns a + `tf.data.Dataset`, and passing the `dataset_fn` into + `tf.distribute.experimental.coordinator. + ClusterCoordinator.create_per_worker_dataset`. We recommend the dataset to be + shuffled and repeated to have the examples run through the training as evenly + as possible. + + ```python + def dataset_fn(): + filenames = ... + dataset = tf.data.Dataset.from_tensor_slices(filenames) + + # Dataset is recommended to be shuffled, and repeated. + return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...) + + coordinator = + tf.distribute.experimental.coordinator.ClusterCoordinator(strategy=...) + distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn) + + ``` + + __Limitations__ + + * `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental, + and the API is subject to further changes. + + * `tf.distribute.experimental.ParameterServerStrategy` does not yet support + training with GPU(s). This is a feature request being developed. + + * `tf.distribute.experimental.ParameterServerStrategy` only supports + [custom training loop + API](https://www.tensorflow.org/tutorials/distribute/custom_training) + currently in TF2. Usage of it with Keras `compile`/`fit` API is being + developed. + + * `tf.distribute.experimental.ParameterServerStrategy` must be used with + `tf.distribute.experimental.coordinator.ClusterCoordinator`. """ + # pyformat: disable def __init__(self, cluster_resolver, variable_partitioner=None): - """Initializes the V2 parameter server strategy. + """Initializes the TF2 parameter server strategy. - This also connects to the remote server cluster. + This initializes the `tf.distribute.experimental.ParameterServerStrategy` + object to be ready for use with + `tf.distribute.experimental.coordinator.ClusterCoordinator`. Args: cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver` object. - variable_partitioner: a callable with the signature `num_partitions = - fn(shape, dtype)`, where `num_partitions` is a list/tuple representing - the number of partitions on each axis, and `shape` and `dtype` are of - types `tf.TensorShape` and `tf.dtypes.Dtype`. If None, variables will - not be partitioned. * `variable_partitioner` will be called for all - variables created under strategy `scope` to instruct how the variables - should be partitioned. Variables will be partitioned if there are more - than one partitions along the partitioning axis, otherwise it falls back - to normal `tf.Variable`. * Only the first / outermost axis partitioning - is supported, namely, elements in `num_partitions` must be 1 other than - the first element. * Partitioner like `min_max_variable_partitioner`, - `variable_axis_size_partitioner` and `fixed_size_partitioner` are also - supported since they conform to the required signature. * Div partition - strategy is used to partition variables. Assuming we assign consecutive - integer ids along the first axis of a variable, then ids are assigned to - shards in a contiguous manner, while attempting to keep each shard size - identical. If the ids do not evenly divide the number of shards, each of - the first several shards will be assigned one more id. For instance, a - variable whose first dimension is - 13 has 13 ids, and they are split across 5 shards as: `[[0, 1, 2], [3, - 4, 5], [6, 7, 8], [9, 10], [11, 12]]`. * Variables created under - `strategy.extended.colocate_vars_with` will not be partitioned, e.g, - optimizer's slot variables. + variable_partitioner: + a `distribute.experimental.partitioners.Partitioner` that specifies + how to partition variables. If `None`, variables will not be + partitioned. + + * Predefined partitioners in `tf.distribute.experimental.partitioners` + can be used for this argument. A commonly used partitioner is + `MinSizePartitioner(min_shard_bytes = 256 << 10, max_shards = num_ps)`, + which allocates at least 256K per shard, and each ps gets at most one + shard. + + * `variable_partitioner` will be called for each variable created under + strategy `scope` to instruct how the variable should be partitioned. + Variables that have only one partition along the partitioning axis + (i.e., no need for partition) will be created as normal `tf.Variable`. + + * Only the first / outermost axis partitioning is supported. + + * Div partition strategy is used to partition variables. Assuming we + assign consecutive integer ids along the first axis of a variable, then + ids are assigned to shards in a contiguous manner, while attempting to + keep each shard size identical. If the ids do not evenly divide the + number of shards, each of the first several shards will be assigned one + more id. For instance, a variable whose first dimension is 13 has 13 + ids, and they are split across 5 shards as: + `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`. + + * Variables created under `strategy.extended.colocate_vars_with` will + not be partitioned. """ + # pyformat: enable self._cluster_resolver = cluster_resolver self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver, variable_partitioner) self._verify_args_and_config(cluster_resolver) logging.info( - "ParameterServerStrategyV2 is initialized with cluster_spec: " - "%s", cluster_resolver.cluster_spec()) + "`tf.distribute.experimental.ParameterServerStrategy` is initialized " + "with cluster_spec: %s", cluster_resolver.cluster_spec()) - # TODO(b/167894802): Make chief, worker, and ps names customizable. - self._connect_to_cluster(client_name="chief") + # TODO(b/167894802): Make coordinator, worker, and ps names customizable. + self._connect_to_cluster(coordinator_name="chief") super(ParameterServerStrategyV2, self).__init__(self._extended) distribute_lib.distribution_strategy_gauge.get_cell("V2").set( "ParameterServerStrategy") - def _connect_to_cluster(self, client_name): - if client_name in ["worker", "ps"]: - raise ValueError("Client name should not be 'worker' or 'ps'.") + def _connect_to_cluster(self, coordinator_name): + if coordinator_name in ["worker", "ps"]: + raise ValueError("coordinator name should not be 'worker' or 'ps'.") cluster_spec = self._cluster_resolver.cluster_spec() self._num_workers = len(cluster_spec.as_dict().get("worker", ())) self._num_ps = len(cluster_spec.as_dict().get("ps", ())) device_filters = server_lib.ClusterDeviceFilters() - # For any worker, only the devices on PS and chief nodes are visible + # For any worker, only the devices on ps and coordinator nodes are visible for i in range(self._num_workers): device_filters.set_device_filters( - "worker", i, ["/job:ps", "/job:%s" % client_name]) - # Similarly for any ps, only the devices on workers and chief are visible + "worker", i, ["/job:ps", "/job:%s" % coordinator_name]) + # Similarly for any ps, only the devices on workers and coordinator are + # visible for i in range(self._num_ps): device_filters.set_device_filters( - "ps", i, ["/job:worker", "/job:%s" % client_name]) + "ps", i, ["/job:worker", "/job:%s" % coordinator_name]) # Allow at most one outstanding RPC for each worker at a certain time. This # is to simplify worker failure handling in the runtime @@ -122,7 +461,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy): self.__class__.__name__, cluster_spec) remote.connect_to_cluster( cluster_spec, - job_name=client_name, + job_name=coordinator_name, protocol=self._cluster_resolver.rpc_layer, cluster_device_filters=device_filters) @@ -134,7 +473,7 @@ class ParameterServerStrategyV2(distribute_lib.Strategy): def _verify_args_and_config(self, cluster_resolver): if not cluster_resolver.cluster_spec(): raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.") - if self.extended._num_gpus_per_worker > 1: + if self.extended._num_gpus_per_worker > 1: # pylint: disable=protected-access raise NotImplementedError("Multi-gpu is not supported yet.") @@ -205,8 +544,8 @@ class ParameterServerStrategyV2Extended( init_from_fn = False initial_value = initial_value() if not init_from_fn: - # The initial_value is created on client, it will need to be sent to - # PS for variable initialization, which can be inefficient and can + # The initial_value is created on coordinator, it will need to be sent to + # ps for variable initialization, which can be inefficient and can # potentially hit the 2GB limit on protobuf serialization. initial_value = ops.convert_to_tensor(initial_value, dtype=dtype) dtype = initial_value.dtype @@ -248,25 +587,39 @@ class ParameterServerStrategyV2Extended( logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) return initial_value[offsets[shard_index]:offsets[shard_index + 1]] + partition_shape = (offsets[shard_index + 1] - + offsets[shard_index],) + shape[1:] + partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:]) arg_spec = tf_inspect.getfullargspec(initial_value) if ("shard_info" not in arg_spec.args and "shard_info" not in arg_spec.kwonlyargs): - # `initial_value` is a callable that doesn't accept `shard_info`. - logging.log_if( - logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and - shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) - full_value = initial_value() - return full_value[offsets[shard_index]:offsets[shard_index + 1]] + try: + value = initial_value( + partition_shape=partition_shape, + partition_offset=partition_offset) + except (TypeError, ValueError): + # TypeError: Initializer doesn't accept kwargs + # ValueError: Initializer doesn't accept partition kwargs + # In both cases we go ahead creating the full value and then slice. + value = initial_value() + + if value.shape == partition_shape: + # Initializer supports partition: value is the partition value. + return value + else: + # Initializer doesn't support partition: value is the full value + # and needs to be sliced to get the partition value. + logging.log_if( + logging.WARN, _INEFFICIENT_INIT_WARNING % name, + shard_index == 0 and + shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) + return value[offsets[shard_index]:offsets[shard_index + 1]] else: - # Memory-efficient way of initializing sharded variable. It requires - # the `init_fn` to accept a namedtuple `shard_info`. - component_shape = (offsets[shard_index + 1] - - offsets[shard_index],) + shape[1:] - offsets_all_axes = (offsets[shard_index],) + (0,) * len(shape[1:]) + # For compatibility with `CheckpointInitialValueCallable`. return initial_value( shard_info=trackable.ShardInfo( - shape=tensor_shape.as_shape(component_shape), - offset=offsets_all_axes)) + shape=tensor_shape.as_shape(partition_shape), + offset=partition_offset)) var_list = [] for i in range(num_partitions): @@ -292,6 +645,22 @@ class ParameterServerStrategyV2Extended( self._variable_count += 1 return var + def _experimental_distribute_dataset(self, dataset, options): + if not ops.get_default_graph().building_function: + raise ValueError( + "The `experimental_distribute_dataset` method must be called inside " + "a `tf.function` passed to `create_per_worker_dataset` of " + "`tf.distribute.experimental.coordinator.ClusterCoordinator`") + return dataset + + def _distribute_datasets_from_function(self, dataset_fn, options): + if not ops.get_default_graph().building_function: + raise ValueError( + "The `distribute_datasets_from_function` method must be called " + "inside a `tf.function` passed to `create_per_worker_dataset` of " + "`tf.distribute.experimental.coordinator.ClusterCoordinator`") + return dataset_fn(distribute_lib.InputContext()) + def _call_for_each_replica(self, fn, args, kwargs): with distribute_lib.ReplicaContext( self._container_strategy(), @@ -299,6 +668,11 @@ class ParameterServerStrategyV2Extended( # TODO(rchao): Support multi-replica per worker or sync-group. return distribute_utils.regroup((fn(*args, **kwargs),)) + def _reduce(self, reduce_op, value): + # TODO(rchao): Provide implementation for multi-replica. Also look into why + # the default implementation is not working. + return value + # The warning that will be logged if the way we initialize sharded variables # is memory-inefficient. diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2_test.py b/tensorflow/python/distribute/parameter_server_strategy_v2_test.py index d7c447a756f..b097c5961b1 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2_test.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2_test.py @@ -35,7 +35,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops_v2 from tensorflow.python.ops import linalg_ops_impl -from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variables from tensorflow.python.training.server_lib import ClusterSpec from tensorflow.python.training.tracking import tracking @@ -77,11 +76,13 @@ class ParameterServerStrategyV2Test(test.TestCase): class PartitionAwareIdentity(object): - def __call__(self, shape, dtype, shard_info): + def __call__(self, shape, dtype, **kwargs): value = linalg_ops_impl.eye(*shape, dtype=dtype) - if shard_info is not None: - value = array_ops.slice(value, shard_info.offset, shard_info.shape) - return value + if "partition_shape" in kwargs and "partition_offset" in kwargs: + return array_ops.slice(value, kwargs["partition_offset"], + kwargs["partition_shape"]) + raise AssertionError("PartitionAwareIdentity do not support " + "non-partitioned initialization") class VariablePartitioningTest(test.TestCase, parameterized.TestCase): @@ -108,7 +109,7 @@ class VariablePartitioningTest(test.TestCase, parameterized.TestCase): def testBasic(self): strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( - self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2)) + self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2)) with strategy.scope(): init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) v1 = variables.Variable( @@ -138,7 +139,7 @@ class VariablePartitioningTest(test.TestCase, parameterized.TestCase): def testNonCallableInitialValue(self): strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( - self.cluster_resolver, partitioned_variables.fixed_size_partitioner(4)) + self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4)) with strategy.scope(): v = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) @@ -155,7 +156,7 @@ class VariablePartitioningTest(test.TestCase, parameterized.TestCase): def testNumPartitionsLargerThanSize(self): strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( - self.cluster_resolver, partitioned_variables.fixed_size_partitioner(4)) + self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4)) with strategy.scope(): v = variables.Variable([0, 1, 2]) @@ -170,8 +171,8 @@ class VariablePartitioningTest(test.TestCase, parameterized.TestCase): def testPartitionToOne(self): # For small variables there is only one partition. - variable_partitioner = partitioned_variables.min_max_variable_partitioner( - max_partitions=2, min_slice_size=64 << 20) + variable_partitioner = sharded_variable.MinSizePartitioner( + min_shard_bytes=64 << 20, max_shards=2) strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( self.cluster_resolver, variable_partitioner) with strategy.scope(): @@ -195,7 +196,7 @@ class VariablePartitioningTest(test.TestCase, parameterized.TestCase): def testColocateWith(self): strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( - self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2)) + self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2)) with strategy.scope(): v1 = variables.Variable([0, 1, 2, 3]) @@ -209,9 +210,9 @@ class VariablePartitioningTest(test.TestCase, parameterized.TestCase): self.assertEqual(v2.device, v1.variables[0].device) self.assertAllEqual(v2, [4, 5]) - def testPartitionAwareInitializer(self): + def testCustomPartitionAwareInitializer(self): strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( - self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2)) + self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2)) with strategy.scope(): initializer = PartitionAwareIdentity() initial_value = functools.partial( @@ -228,7 +229,7 @@ class VariablePartitioningTest(test.TestCase, parameterized.TestCase): def testPartitionWhenLackOfInfo(self): strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( - self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2)) + self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2)) with strategy.scope(): initializer = init_ops_v2.Constant([0, 1, 2, 3]) # Shape is not explicitly specified. @@ -278,7 +279,7 @@ class VariablePartitioningTest(test.TestCase, parameterized.TestCase): def testCreateInsideTFFunction(self): strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( - self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2)) + self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2)) collection = [] @@ -327,7 +328,7 @@ class VariablePartitioningTest(test.TestCase, parameterized.TestCase): getter=make_variable) strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( - self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2)) + self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2)) ckpt_dir = os.path.join(self.get_temp_dir(), "checkpoint") with strategy.scope(): @@ -342,7 +343,7 @@ class VariablePartitioningTest(test.TestCase, parameterized.TestCase): strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( self.cluster_resolver, - partitioned_variables.fixed_size_partitioner(restore_shards)) + sharded_variable.FixedShardsPartitioner(restore_shards)) with strategy.scope(): model2 = Model() diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py index 7b757db95cb..553d82e4a26 100644 --- a/tensorflow/python/distribute/sharded_variable.py +++ b/tensorflow/python/distribute/sharded_variable.py @@ -19,64 +19,236 @@ from __future__ import print_function import copy +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables as variables_lib from tensorflow.python.saved_model import save_context from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.util import dispatch +from tensorflow.python.util.tf_export import tf_export -class ShardedVariable(trackable.Trackable): - """A container for `Variables` that should be treated as shards. +@tf_export('distribute.experimental.partitioners.Partitioner', v1=[]) +class Partitioner(object): + """Partitioner base class: all partitiners inherit from this class. - Variables that are too large to fit on a single device (e.g., large - embeddings) - may need to be sharded over multiple devices. This class maintains a list of - smaller variables that can be independently stored on separate devices (eg, - multiple parameter servers), and saves and restores those variables as if they - were a single larger variable. + Partitioners should implement a `__call__` method with the following + signature: - Objects of this class can be saved with a given number of shards and then - restored from a checkpoint into a different number of shards. - - Objects of this class can be saved to SavedModel format using - `tf.saved_model.save`. The SavedModel can be used by programs like TF serving - APIs. It is not yet supported to load the SavedModel with - `tf.saved_model.load`. - - Since `ShardedVariable` can be saved and then restored to different number of - shards depending on the restore environments, for example, TF serving APIs - would restore to one shard for serving efficiency, when using - `ShardedVariable` in a tf.function, one should generally not assume it has the - same number of shards across save and load. - - Sharding is only supported along the first dimension. - - >>> class Model(tf.Module): - ... def __init__(self): - ... self.sharded_variable = ShardedVariable([ - ... tf.Variable([3.0], dtype=tf.float32), - ... tf.Variable([2.0], dtype=tf.float32) - ... ]) - ... - ... @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)]) - ... def fn(self, x): - ... return tf.nn.embedding_lookup(self.sharded_variable.variables, x) - ... - ... @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)]) - ... def serve_fn(self, x): - ... return tf.nn.embedding_lookup(self.sharded_variable.variables, x) - >>> - >>> model = Model() - >>> model.fn(1).numpy() - 2.0 - >>> tf.saved_model.save(model, export_dir='/tmp/saved_model', - ... signatures=model.serve_fn) + ```python + def __call__(self, shape, dtype, axis=0): + # Partitions the given `shape` and returns the partition results. + # See docstring of `__call__` method for the format of partition results. + ``` """ + def __call__(self, shape, dtype, axis=0): + """Partitions the given `shape` and returns the partition results. + + Examples of a partitioner that allocates a fixed number of shards: + + ```python + partitioner = FixedShardsPartitioner(num_shards=2) + partitions = partitioner(tf.TensorShape([10, 3], tf.float32), axis=0) + print(partitions) # [2, 0] + ``` + + Args: + shape: a `tf.TensorShape`, the shape to partition. + dtype: a `tf.dtypes.Dtype` indicating the type of the partition value. + axis: The axis to partition along. Default: outermost axis. + + Returns: + A list of integers representing the number of partitions on each axis, + where i-th value correponds to i-th axis. + """ + raise NotImplementedError + + +@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[]) +class FixedShardsPartitioner(Partitioner): + """Partitioner that allocates a fixed number of shards. + + Examples: + + >>> # standalone usage: + >>> partitioner = FixedShardsPartitioner(num_shards=2) + >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32) + >>> [2, 1] + >>> + >>> # use in ParameterServerStrategy + >>> # strategy = tf.distribute.experimental.ParameterServerStrategy( + >>> # cluster_resolver=cluster_resolver, variable_partitioner=partitioner) + + """ + + def __init__(self, num_shards): + """Creates a new `FixedShardsPartitioner`. + + Args: + num_shards: `int`, number of shards to partition. + """ + self._num_shards = num_shards + + def __call__(self, shape, dtype, axis=0): + del dtype + result = [1] * len(shape) + result[axis] = min(self._num_shards, shape.dims[axis].value) + return result + + +@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[]) +class MinSizePartitioner(Partitioner): + """Partitioner that allocates a minimum size per shard. + + This partitioner ensures each shard has at least `min_shard_bytes`, and tries + to allocate as many shards as possible, i.e., keeping shard size as small as + possible. The maximum number of such shards (upper bound) is given by + `max_shards`. + + Examples: + + >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2) + >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32) + >>> [2, 1] + >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10) + >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32) + >>> [6, 1] + >>> + >>> # use in ParameterServerStrategy + >>> # strategy = tf.distribute.experimental.ParameterServerStrategy( + >>> # cluster_resolver=cluster_resolver, variable_partitioner=partitioner) + """ + + def __init__(self, + min_shard_bytes=256 << 10, + max_shards=1, + bytes_per_string=16): + """Creates a new `MinSizePartitioner`. + + Args: + min_shard_bytes: Minimum bytes of each shard. Defaults to 256K. + max_shards: Upper bound on the number of shards. Defaults to 1. + bytes_per_string: If the partition value is of type string, this provides + an estimate of how large each string is. + """ + if min_shard_bytes < 1: + raise ValueError('min_shard_bytes must be positive, got: %r' % + min_shard_bytes) + if max_shards < 1: + raise ValueError('max_shards must be positive, got: %r' % max_shards) + if bytes_per_string < 1: + raise ValueError('bytes_per_string must be positive, got: %r' % + bytes_per_string) + self._min_shard_bytes = min_shard_bytes + self._max_shards = max_shards + self._bytes_per_string = bytes_per_string + + def __call__(self, shape, dtype, axis=0): + return partitioned_variables.min_max_variable_partitioner( + max_partitions=self._max_shards, + axis=axis, + min_slice_size=self._min_shard_bytes, + bytes_per_string_element=self._bytes_per_string)(shape, dtype) + + +@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[]) +class MaxSizePartitioner(Partitioner): + """Partitioner that keeps shards below `max_shard_bytes`. + + This partitioner ensures each shard has at most `max_shard_bytes`, and tries + to allocate as few shards as possible, i.e., keeping shard size as large + as possible. + + If the partitioner hits the `max_shards` limit, then each shard may end up + larger than `max_shard_bytes`. By default `max_shards` equals `None` and no + limit on the number of shards is enforced. + + Examples: + + >>> partitioner = MaxSizePartitioner(max_shard_bytes=4) + >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32) + >>> [6, 1] + >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2) + >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32) + >>> [2, 1] + >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024) + >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32) + >>> [1, 1] + >>> + >>> # use in ParameterServerStrategy + >>> # strategy = tf.distribute.experimental.ParameterServerStrategy( + >>> # cluster_resolver=cluster_resolver, variable_partitioner=partitioner) + """ + + def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16): + """Creates a new `MaxSizePartitioner`. + + Args: + max_shard_bytes: The maximum size any given shard is allowed to be. + max_shards: The maximum number of shards in `int` created taking + precedence over `max_shard_bytes`. + bytes_per_string: If the partition value is of type string, this provides + an estimate of how large each string is. + """ + if max_shard_bytes < 1: + raise ValueError('max_shard_bytes must be positive, got: %r' % + max_shard_bytes) + if max_shards and max_shards < 1: + raise ValueError('max_shards must be positive, got: %r' % max_shards) + if bytes_per_string < 1: + raise ValueError('bytes_per_string must be positive, got: %r' % + bytes_per_string) + + self._max_shard_bytes = max_shard_bytes + self._max_shards = max_shards + self._bytes_per_string = bytes_per_string + + def __call__(self, shape, dtype, axis=0): + return partitioned_variables.variable_axis_size_partitioner( + max_shard_bytes=self._max_shard_bytes, + max_shards=self._max_shards, + bytes_per_string_element=self._bytes_per_string, + axis=axis)(shape, dtype) + + +class ShardedVariableSpec(type_spec.TypeSpec): + """Type specification for a `ShardedVariable`.""" + + __slots__ = ['_variable_specs'] + + value_type = property(lambda self: ShardedVariable) + + def __init__(self, *variable_specs): + self._variable_specs = tuple(variable_specs) + + def _serialize(self): + return self._variable_specs + + @property + def _component_specs(self): + return self._variable_specs + + def _to_components(self, value): + return value.variables + + def _from_components(self, variables): + return ShardedVariable(variables) + + +class ShardedVariableMixin(trackable.Trackable): + """Mixin for ShardedVariable.""" + + # TODO(b/170877138): Remove this mixin once fixed. This mixin is required + # since TPUShardedVariable can't be a CompositeTensor. + def __init__(self, variables, name='ShardedVariable'): """Treats `variables` as shards of a larger Variable. @@ -89,17 +261,17 @@ class ShardedVariable(trackable.Trackable): tf.Variable(..., shape=(15, 100), dtype=tf.float32), tf.Variable(..., shape=(5, 100), dtype=tf.float32) ] - sharded_variable = ShardedVariable(variables) + sharded_variable = ShardedVariableMixin(variables) assert sharded_variable.shape.as_list() == [30, 100] ``` Args: variables: A list of `ResourceVariable`s that comprise this sharded variable. Variables should not be shared between different - `ShardedVariable` objects. + `ShardedVariableMixin` objects. name: String. Name of this container. Defaults to "ShardedVariable". """ - super(ShardedVariable, self).__init__() + super(ShardedVariableMixin, self).__init__() self._variables = variables self._name = name @@ -149,6 +321,12 @@ class ShardedVariable(trackable.Trackable): """Return an iterable for accessing the underlying sharded variables.""" return iter(self._variables) + @property + def _type_spec(self): + return ShardedVariableSpec(*( + resource_variable_ops.VariableSpec(v.shape, v.dtype) + for v in self._variables)) + @property def variables(self): """The list of `Variable`s that make up the shards of this object.""" @@ -220,7 +398,63 @@ class ShardedVariable(trackable.Trackable): return obj_map, resource_map +class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor): + """A container for `Variables` that should be treated as shards. + + Variables that are too large to fit on a single device (e.g., large + embeddings) + may need to be sharded over multiple devices. This class maintains a list of + smaller variables that can be independently stored on separate devices (eg, + multiple parameter servers), and saves and restores those variables as if they + were a single larger variable. + + Objects of this class can be saved with a given number of shards and then + restored from a checkpoint into a different number of shards. + + Objects of this class can be saved to SavedModel format using + `tf.saved_model.save`. The SavedModel can be used by programs like TF serving + APIs. It is not yet supported to load the SavedModel with + `tf.saved_model.load`. + + Since `ShardedVariable` can be saved and then restored to different number of + shards depending on the restore environments, for example, TF serving APIs + would restore to one shard for serving efficiency, when using + `ShardedVariable` in a tf.function, one should generally not assume it has the + same number of shards across save and load. + + Sharding is only supported along the first dimension. + + >>> class Model(tf.Module): + ... def __init__(self): + ... self.sharded_variable = ShardedVariable([ + ... tf.Variable([3.0], dtype=tf.float32), + ... tf.Variable([2.0], dtype=tf.float32) + ... ]) + ... + ... @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)]) + ... def fn(self, x): + ... return tf.nn.embedding_lookup(self.sharded_variable.variables, x) + ... + ... @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)]) + ... def serve_fn(self, x): + ... return tf.nn.embedding_lookup(self.sharded_variable.variables, x) + >>> + >>> model = Model() + >>> model.fn(1).numpy() + 2.0 + >>> tf.saved_model.save(model, export_dir='/tmp/saved_model', + ... signatures=model.serve_fn) + """ + + @property + def _type_spec(self): + return ShardedVariableSpec(*( + resource_variable_ops.VariableSpec(v.shape, v.dtype) + for v in self._variables)) + + def _var_to_tensor(var, dtype=None, name=None, as_ref=False): + """Converts a `ShardedVariable` to a `Tensor`.""" del name if dtype is not None and not dtype.is_compatible_with(var.dtype): raise ValueError( @@ -229,9 +463,40 @@ def _var_to_tensor(var, dtype=None, name=None, as_ref=False): if as_ref: raise NotImplementedError( "ShardedVariable doesn't support being used as a reference.") + # We use op dispatch mechanism to override embedding_lookup ops when called + # with ShardedVariable. This requires embedding_lookup ops to raise TypeError + # when called with ShardedVariable. However since ShardedVariable can be + # converted to a tensor via concat, embedding_lookup ops would silently + # do the convertion and never raise a TypeError. To be able to properly + # raise a TypeError, namescope is used to detect if this method is called + # within a embedding_lookup op. + # NOTE: This doesn't work in eager mode since op namescope is always cleared + # in eager. This also breaks if user sets the name of embedding_lookup op + # with something that doesn't contain str "embedding_lookup". + # + # TODO(chenkai): Find a more robust way to do this, which should not rely + # on namescope. + if 'embedding_lookup' in ops.get_name_scope(): + raise TypeError('Converting ShardedVariable to tensor in embedding lookup' + ' ops is disallowed.') return array_ops.concat(var.variables, axis=0) # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor) + + +# Override the behavior of embedding_lookup(sharded_variable, ...) +@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable) +def embedding_lookup(params, + ids, + partition_strategy='mod', + name=None, + validate_indices=True, + max_norm=None): + if isinstance(params, list): + params = params[0] + return embedding_ops.embedding_lookup(params.variables, ids, + partition_strategy, name, + validate_indices, max_norm) diff --git a/tensorflow/python/distribute/sharded_variable_test.py b/tensorflow/python/distribute/sharded_variable_test.py index f04e5b248a3..8b88d7b016e 100644 --- a/tensorflow/python/distribute/sharded_variable_test.py +++ b/tensorflow/python/distribute/sharded_variable_test.py @@ -24,11 +24,17 @@ from tensorflow.python.client import session as session_lib from tensorflow.python.compat import v2_compat from tensorflow.python.distribute import sharded_variable from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test from tensorflow.python.saved_model import loader @@ -37,6 +43,7 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import util +from tensorflow.python.util import nest def _load_and_run( @@ -60,6 +67,39 @@ def _load_and_run( return session.run(output_dict, feed_dict=feed_dict) +class PartitionerTest(test.TestCase): + + def test_fixed_shards_partitioner(self): + partitioner = sharded_variable.FixedShardsPartitioner(num_shards=2) + got = partitioner(tensor_shape.TensorShape([10, 3]), dtypes.float32) + self.assertAllEqual(got, [2, 1]) + + def test_min_size_partitioner(self): + partitioner = sharded_variable.MinSizePartitioner( + min_shard_bytes=4, max_shards=2) + got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32) + self.assertAllEqual(got, [2, 1]) + + partitioner = sharded_variable.MinSizePartitioner( + min_shard_bytes=4, max_shards=10) + got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32) + self.assertAllEqual(got, [6, 1]) + + def test_max_size_partitioner(self): + partitioner = sharded_variable.MaxSizePartitioner(max_shard_bytes=4) + got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32) + self.assertAllEqual(got, [6, 1]) + + partitioner = sharded_variable.MaxSizePartitioner( + max_shard_bytes=4, max_shards=2) + got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32) + self.assertAllEqual(got, [2, 1]) + + partitioner = sharded_variable.MaxSizePartitioner(max_shard_bytes=1024) + got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32) + self.assertAllEqual(got, [1, 1]) + + class ShardedVariableTest(test.TestCase): def test_sharded_variable_simple(self): @@ -286,6 +326,194 @@ class ShardedVariableTest(test.TestCase): full_name='s', full_shape=[2], var_offset=[0], var_shape=[1])) sharded_variable.ShardedVariable([v]) + def test_as_function_input(self): + variables1 = [ + variables_lib.Variable([1]), + variables_lib.Variable([1]), + ] + s = sharded_variable.ShardedVariable(variables1) + variables2 = [ + variables_lib.Variable([2]), + variables_lib.Variable([2]), + ] + s2 = sharded_variable.ShardedVariable(variables2) + + trace_count = [0] + + @def_function.function + def func(sharded_var): + trace_count[0] = trace_count[0] + 1 + sharded_var.assign([0, 0]) + + func(s) + self.assertAllEqual(ops.convert_to_tensor(s), [0, 0]) + self.assertEqual(trace_count[0], 1) + func(s2) + self.assertAllEqual(ops.convert_to_tensor(s2), [0, 0]) + self.assertEqual(trace_count[0], 1) + + def test_flatten(self): + variables = [ + variables_lib.Variable([0]), + variables_lib.Variable([1]), + ] + s = sharded_variable.ShardedVariable(variables) + + got = nest.flatten(s) + self.assertEqual(s, got[0]) + + got = nest.flatten(s, expand_composites=True) + self.assertAllEqual(variables, got) + + def test_tf_module(self): + + class Model(module.Module): + + def __init__(self): + super().__init__() + variables = [ + variables_lib.Variable([0]), + variables_lib.Variable([1]), + ] + self.w = sharded_variable.ShardedVariable(variables) + + model = Model() + + self.assertLen(model.variables, 2) + self.assertEqual(model.variables[0], [0]) + self.assertEqual(model.variables[1], [1]) + self.assertAllEqual(model.variables, model.trainable_variables) + + self.assertLen(model._checkpoint_dependencies, 1) + self.assertEqual(model._checkpoint_dependencies[0].ref, model.w) + + def test_keras_layer_setattr(self): + + class Layer(base_layer.Layer): + + def __init__(self): + super().__init__() + variables1 = [ + variables_lib.Variable([0]), + variables_lib.Variable([1]), + ] + variables2 = [ + variables_lib.Variable([2], trainable=False), + variables_lib.Variable([3], trainable=False), + ] + self.w = sharded_variable.ShardedVariable(variables1) + self.b = sharded_variable.ShardedVariable(variables2) + + layer = Layer() + + self.assertLen(layer.trainable_weights, 2) + self.assertEqual(layer.trainable_weights[0], [0]) + self.assertEqual(layer.trainable_weights[1], [1]) + self.assertLen(layer.non_trainable_weights, 2) + self.assertEqual(layer.non_trainable_weights[0], [2]) + self.assertEqual(layer.non_trainable_weights[1], [3]) + self.assertAllEqual(layer.weights, + layer.trainable_weights + layer.non_trainable_weights) + self.assertAllEqual(layer.trainable_weights, layer.trainable_variables) + self.assertAllEqual(layer.weights, layer.variables) + + checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies) + self.assertEqual(checkpoint_deps, set([layer.w, layer.b])) + + def test_keras_layer_add_weight(self): + + class Layer(base_layer.Layer): + + def __init__(self): + super().__init__() + self.w = self.add_weight( + shape=(2,), initializer=lambda shape, dtype: [0, 1], trainable=True) + self.b = self.add_weight( + shape=(2,), + initializer=lambda shape, dtype: [2, 3], + trainable=False) + + def sharded_variable_creator(next_creator, **kwargs): + v1_value = kwargs['initial_value']()[0:1] + v2_value = kwargs['initial_value']()[1:] + + kwargs['initial_value'] = v1_value + kwargs['shape'] = (1,) + v1 = next_creator(**kwargs) + + kwargs['initial_value'] = v2_value + kwargs['shape'] = (1,) + v2 = next_creator(**kwargs) + + return sharded_variable.ShardedVariable([v1, v2]) + + with variable_scope.variable_creator_scope(sharded_variable_creator): + layer = Layer() + + self.assertLen(layer.trainable_weights, 2) + self.assertEqual(layer.trainable_weights[0], [0]) + self.assertEqual(layer.trainable_weights[1], [1]) + self.assertLen(layer.non_trainable_weights, 2) + self.assertEqual(layer.non_trainable_weights[0], [2]) + self.assertEqual(layer.non_trainable_weights[1], [3]) + self.assertAllEqual(layer.weights, + layer.trainable_weights + layer.non_trainable_weights) + self.assertAllEqual(layer.trainable_weights, layer.trainable_variables) + self.assertAllEqual(layer.weights, layer.variables) + + checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies) + self.assertEqual(checkpoint_deps, set([layer.w, layer.b])) + + def test_embedding_lookup(self): + v = [ + variables_lib.Variable([[1., 2.], [3., 4.]]), + variables_lib.Variable([[5., 6.], [7., 8.]]), + variables_lib.Variable([[9., 10.]]) + ] + sv = sharded_variable.ShardedVariable(v) + + @def_function.function + def lookup(): + ids = constant_op.constant([0, 3, 4]) + return embedding_ops.embedding_lookup_v2(sv, ids) + + @def_function.function + def sparse_lookup(): + sp_ids = sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1], [1, 0], [2, 2]], + values=[0, 3, 4, 1], + dense_shape=[3, 3]) + return embedding_ops.embedding_lookup_sparse_v2(sv, sp_ids, None) + + @def_function.function + def safe_sparse_lookup(): + sp_ids = sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1], [1, 0], [2, 2]], + values=[0, -1, 4, 1], + dense_shape=[3, 3]) + sp_weights = sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1], [1, 0], [2, 2]], + values=[1., 1., -1., 1.], + dense_shape=[3, 3]) + return embedding_ops.safe_embedding_lookup_sparse_v2( + sv, sp_ids, sp_weights) + + # TODO(chenkai): Add safe_sparse_lookup to the list. Currently + # ShardedVariable is converted to a tensor in safe_sparse_lookup. + for func in [lookup, sparse_lookup]: + num_gather_ops = 0 + for op in func.get_concrete_function().graph.get_operations(): + if op.type == 'ResourceGather': + num_gather_ops += 1 + self.assertEqual( + num_gather_ops, len(v), 'Number of ResourceGather op does not match' + ' expected, possibly due to ShardedVariable accidentally being' + ' converted to tensor in embedding_lookup ops.') + + self.assertAllEqual(lookup(), [[1., 2.], [7., 8.], [9., 10.]]) + self.assertAllClose(sparse_lookup(), [[4., 5.], [9., 10.], [3., 4.]]) + self.assertAllClose(safe_sparse_lookup(), [[1., 2.], [0., 0.], [3., 4.]]) + if __name__ == '__main__': v2_compat.enable_v2_behavior() diff --git a/tensorflow/python/distribute/strategy_gather_test.py b/tensorflow/python/distribute/strategy_gather_test.py index 7cefcf396db..9c70f1d34b3 100644 --- a/tensorflow/python/distribute/strategy_gather_test.py +++ b/tensorflow/python/distribute/strategy_gather_test.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.platform import test from tensorflow.python.util import nest @@ -74,7 +75,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): lambda _: array_ops.identity(value_on_replica)) def run(): - return strategy._gather(distributed_values, axis=axis) + return strategy.gather(distributed_values, axis=axis) if not pure_eager: run = def_function.function(run) @@ -133,7 +134,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): axis = 0 def run(): - return strategy._gather(distributed_values, axis=axis) + return strategy.gather(distributed_values, axis=axis) if not pure_eager: run = def_function.function(run) @@ -155,7 +156,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): axis = 1 def run(): - return strategy._gather(distributed_values, axis=axis) + return strategy.gather(distributed_values, axis=axis) if not pure_eager: run = def_function.function(run) @@ -183,7 +184,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): axis = 0 def run(): - return strategy._gather(distributed_values, axis=axis) + return strategy.gather(distributed_values, axis=axis) if not pure_eager: run = def_function.function(run) @@ -210,11 +211,11 @@ class GatherTest(test.TestCase, parameterized.TestCase): values=[[1., 2.]], indices=[2], dense_shape=dense_shape) def run(value): - return strategy._gather(value, axis=0) + return strategy.gather(value, axis=0) with self.assertRaisesRegex( NotImplementedError, - r'gather/all_gather does not support IndexedSlices'): + r'gather does not support IndexedSlices'): if pure_eager: run(t0) else: @@ -235,7 +236,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): axis = 0 def run(): - return strategy._gather(distributed_values, axis=axis) + return strategy.gather(distributed_values, axis=axis) if not pure_eager: run = def_function.function(run) @@ -271,7 +272,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def replica_fn(per_replica_value): ctx = ds_context.get_replica_context() local_value = array_ops.identity(per_replica_value) - return ctx._all_gather(local_value, axis=axis) + return ctx.all_gather(local_value, axis=axis) if not pure_eager: replica_fn = def_function.function(replica_fn) @@ -342,7 +343,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): @def_function.function def replica_fn(per_replica_value): ctx = ds_context.get_replica_context() - return ctx._all_gather(array_ops.identity(per_replica_value), axis=axis) + return ctx.all_gather(array_ops.identity(per_replica_value), axis=axis) result = strategy.experimental_local_results( strategy.run(replica_fn, args=(next(input_iterator),))) @@ -369,7 +370,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def run(value): value_identity = array_ops.identity(value) ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=0) + return ctx.all_gather(value_identity, axis=0) if not pure_eager: run = def_function.function(run) @@ -397,7 +398,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def run(value): value_identity = array_ops.identity(value) ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=1) + return ctx.all_gather(value_identity, axis=1) if not pure_eager: run = def_function.function(run) @@ -436,7 +437,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): value_1 = array_ops.identity(value) value_3 = array_ops.identity(value_2) ctx = ds_context.get_replica_context() - return ctx._all_gather([value_1, value_3], axis=axis) + return ctx.all_gather([value_1, value_3], axis=axis) if not pure_eager: run = def_function.function(run) @@ -455,7 +456,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def run(): value_identity = array_ops.identity(single_value) ctx = ds_context.get_replica_context() - return ctx._all_gather([value_identity, value_identity], axis=axis) + return ctx.all_gather([value_identity, value_identity], axis=axis) if not pure_eager: run = def_function.function(run) @@ -491,7 +492,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def run(value): value_identity = array_ops.identity(value) ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=0) + return ctx.all_gather(value_identity, axis=0) if not pure_eager: run = def_function.function(run) @@ -519,11 +520,11 @@ class GatherTest(test.TestCase, parameterized.TestCase): def replica_fn(value): ctx = ds_context.get_replica_context() - return ctx._all_gather(value, axis=0) + return ctx.all_gather(value, axis=0) with self.assertRaisesRegex( NotImplementedError, - r'gather/all_gather does not support IndexedSlices'): + r'all_gather does not support IndexedSlices'): if not pure_eager: strategy.run(def_function.function(replica_fn), args=(t0,)) else: @@ -548,7 +549,7 @@ class GatherTest(test.TestCase, parameterized.TestCase): def run(value): value_identity = array_ops.identity(value) ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=0) + return ctx.all_gather(value_identity, axis=0) if not pure_eager: run = def_function.function(run) @@ -573,6 +574,71 @@ class GatherTest(test.TestCase, parameterized.TestCase): r'Dimension \d in both shapes must be equal'): strategy.run(run, args=(per_replica_value,)) + def testAllGatherGradient(self, strategy, pure_eager): + if pure_eager: + self.skipTest('`tf.gradients` is not supported with eager execution ' + 'without using tf.functions.') + + def all_gather_fn(value): + axis = 1 + ctx = ds_context.get_replica_context() + return ctx.all_gather(array_ops.identity(value), axis) + + gradient_comp = sum(range(1, strategy.num_replicas_in_sync + 1)) + gradient = [[gradient_comp], [gradient_comp]] + grads_for_all_replicas = [gradient] * _get_num_replicas_per_client(strategy) + + @def_function.function + def step(c): + x = constant_op.constant([[3.], [5.]]) + mid = all_gather_fn(x) + y = mid * c + return gradients_impl.gradients_v2(y, [x])[0] + + def value_fn(ctx): + x = [1., 2., 3., 4., 5., 6., 7., 8.] + return array_ops.constant([x[ctx.replica_id_in_sync_group]]) + + per_replica_value = strategy.experimental_distribute_values_from_function( + value_fn) + result = strategy.experimental_local_results( + strategy.run(step, args=(per_replica_value,))) + + self.assertAllEqual(grads_for_all_replicas, result) + + def testAllGatherGradientNest(self, strategy, pure_eager): + if pure_eager: + self.skipTest('`tf.gradients` is not supported with eager execution ' + 'without using tf.functions.') + + def all_gather_fn(value): + axis = 1 + ctx = ds_context.get_replica_context() + return ctx.all_gather(array_ops.identity(value), axis) + + gradient_comp = sum(range(1, strategy.num_replicas_in_sync + 1)) + gradient = [[gradient_comp], [gradient_comp]] + grads_for_all_replicas = [gradient] * _get_num_replicas_per_client(strategy) + + @def_function.function + def step(c): + x = constant_op.constant([[3.], [5.]]) + y = constant_op.constant([[2.], [4.]]) + mid = all_gather_fn([x, y]) + y = mid * c + return gradients_impl.gradients_v2(y, [x])[0] + + def value_fn(ctx): + x = [1., 2., 3., 4., 5., 6., 7., 8.] + return array_ops.constant([x[ctx.replica_id_in_sync_group]]) + + per_replica_value = strategy.experimental_distribute_values_from_function( + value_fn) + result = strategy.experimental_local_results( + strategy.run(step, args=(per_replica_value,))) + + self.assertAllEqual(grads_for_all_replicas, result) + def _make_indexed_slices(values, indices, dense_shape): tensor = ops.IndexedSlices( diff --git a/tensorflow/python/distribute/test_util.py b/tensorflow/python/distribute/test_util.py index 82867edb4c2..2f04b67347f 100644 --- a/tensorflow/python/distribute/test_util.py +++ b/tensorflow/python/distribute/test_util.py @@ -58,7 +58,7 @@ def _gather(strategy, value): return array_ops.stack(value._values) assert len(strategy.extended.worker_devices) == len(value._values) inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values] - return strategy._gather(values.PerReplica(inputs), axis=0) + return strategy.gather(values.PerReplica(inputs), axis=0) # pylint: enable=protected-access diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 719147db5ee..e7884c0eef3 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -52,7 +52,6 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -740,7 +739,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): atexit.register(async_wait) # Flag to turn on VariablePolicy - self._use_var_policy = True + self._use_var_policy = False def _validate_colocate_with_variable(self, colocate_with_variable): distribute_utils. validate_colocate(colocate_with_variable, self) @@ -803,6 +802,13 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): "distribution function.".format(path, type(spec))) def _experimental_distribute_dataset(self, dataset, options): + if (options and options.experimental_replication_mode == + distribute_lib.InputReplicationMode.PER_REPLICA): + raise NotImplementedError( + "InputReplicationMode.PER_REPLICA " + "is only supported in " + "`experimental_distribute_datasets_from_function`." + ) if options is None or options.experimental_prefetch_to_device: self._check_spec(dataset.element_spec) @@ -813,6 +819,13 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): num_replicas_in_sync=self._num_replicas_in_sync) def _distribute_datasets_from_function(self, dataset_fn, options): + if (options and options.experimental_replication_mode == + distribute_lib.InputReplicationMode.PER_REPLICA): + raise NotImplementedError( + "InputReplicationMode.PER_REPLICA " + "is only supported in " + " `experimental_distribute_datasets_from_function` " + "of tf.distribute.MirroredStrategy") input_workers = self._get_input_workers(options) input_contexts = [] num_workers = input_workers.num_workers @@ -1022,8 +1035,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): distribute_utils.TPU_VARIABLE_CLASS_MAPPING, distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs) - def _gather_to_implementation(self, value, destinations, axis, - experimental_hints): + def _gather_to_implementation(self, value, destinations, axis, options): if not isinstance(value, values.DistributedValues): return value @@ -1070,7 +1082,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): return output - def _reduce_to(self, reduce_op, value, destinations, experimental_hints): + def _reduce_to(self, reduce_op, value, destinations, options): if (isinstance(value, values.DistributedValues) or tensor_util.is_tensor(value) ) and tpu_values.enclosing_tpu_context() is not None: @@ -1412,12 +1424,11 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): return self.strategy.extended.experimental_logical_device(logical_device_id) # TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it. - def _all_gather(self, value, axis, experimental_hints=None): + def all_gather(self, value, axis, experimental_hints=None): del experimental_hints for v in nest.flatten(value): if isinstance(v, ops.IndexedSlices): - raise NotImplementedError("gather/all_gather does not support " - "IndexedSlices") + raise NotImplementedError("all_gather does not support IndexedSlices") def _all_to_all(value, axis): # The underlying AllToAllOp first do a split of the input value and then @@ -1468,12 +1479,8 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): result = array_ops.concat(squeezed, axis=axis) return result - @custom_gradient.custom_gradient - def grad_wrapper(*xs): - ys = [_all_to_all(t, axis=axis) for t in xs] - return ys, lambda *dy_s: self._all_gather(dy_s, axis) - - return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) + ys = [_all_to_all(t, axis=axis) for t in nest.flatten(value)] + return nest.pack_sequence_as(value, ys) def _set_last_step_outputs(ctx, last_step_tensor_outputs): diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index e0afda84359..239882c1571 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -179,7 +179,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): with ops.device("/device:TPU:0"): self.assertAllEqual(func(), 2.0) - def test_sequential_experimental_runs(self, enable_packed_var): + def test_sequential_runs(self, enable_packed_var): resolver = get_tpu_cluster_resolver() remote.connect_to_cluster(resolver) topology = tpu_strategy_util.initialize_tpu_system(resolver) @@ -254,8 +254,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): return strategy.run(computation) - with self.assertRaisesRegex(errors.InvalidArgumentError, - "TPU compilation failed"): + with self.assertRaises(errors.OpError): compilation_failure_run() @def_function.function @@ -476,7 +475,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(expected_result, run(input_iterator)) self.assertAllEqual((0.,), w.read_value()) - def test_experimental_run_output_on_device(self, enable_packed_var): + def test_run_output_on_device(self, enable_packed_var): strategy = get_tpu_strategy(enable_packed_var) def computation(x): diff --git a/tensorflow/python/distribute/v1/cross_device_ops_test.py b/tensorflow/python/distribute/v1/cross_device_ops_test.py index e54e1878d6d..a38c3c705ea 100644 --- a/tensorflow/python/distribute/v1/cross_device_ops_test.py +++ b/tensorflow/python/distribute/v1/cross_device_ops_test.py @@ -432,7 +432,7 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): NUM_WORKERS = 3 -CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication +CollectiveCommunication = collective_util.CollectiveCommunication class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, @@ -470,15 +470,15 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, devices = ["/device:CPU:0"] if use_strategy_object: + comm_options = collective_util.Options(implementation=communication) strategy = (mwms_lib.CollectiveAllReduceStrategy - ._from_local_devices(devices, communication=communication)) # pylint: disable=protected-access + ._from_local_devices(devices, comm_options)) # pylint: disable=protected-access return strategy, devices, "" else: collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( devices=devices, group_size=len(devices), - collective_keys=collective_keys, - communication=communication) + collective_keys=collective_keys) return collective_all_reduce_ops, devices, "" else: # NCCL requires physical GPUs for every replica, which we can't do with @@ -501,16 +501,16 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, task_type=task_type, task_id=task_id, num_accelerators={"GPU": num_gpus}) + comm_options = collective_util.Options(implementation=communication) strategy = mwms_lib.CollectiveAllReduceStrategy( - cluster_resolver=resolver, communication=communication) + communication_options=comm_options, cluster_resolver=resolver) return (strategy, devices, "grpc://" + self._cluster_spec[task_type][task_id]) else: collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( devices=devices, group_size=len(devices) * NUM_WORKERS, - collective_keys=collective_keys, - communication=communication) + collective_keys=collective_keys) return (collective_all_reduce_ops, devices, "grpc://" + self._cluster_spec[task_type][task_id]) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 1464941523b..a59164bb0d7 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -362,8 +362,23 @@ class DistributedDelegate(DistributedValues): class PerReplica(DistributedValues, composite_tensor.CompositeTensor): """Holds a map from replica to unsynchronized values.""" + def __init__(self, values, type_spec_override=None): + super(PerReplica, self).__init__(values) + # Allow setting a type spec that can be different from the underlying + # values. This allows us avoid retracing for PerReplica from full, partial + # and empty batches. In a multi client setup, we need to avoid such + # retracing otherwise the collectives may mismatch since we assign new + # collective keys when retracing the function. + # + # TODO(b/166169298): remove after CrossDeviceOps is tracing safe. + self._type_spec_override = type_spec_override + @property def _type_spec(self): + if self._type_spec_override is not None: + # Return a deep copy in case the caller changes it, since _type_spec() + # normally returns a temporary object. + return copy.deepcopy(self._type_spec_override) return PerReplicaSpec( *(type_spec.type_spec_from_value(v) for v in self._values)) @@ -874,6 +889,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, Returns: Updated variable or `tf.Operation`. """ + values_util.mark_as_unsaveable() return self.distribute_strategy.extended.update( self, update_fn, args=(value,), kwargs=kwargs, group=True) @@ -1155,6 +1171,7 @@ class SyncOnReadVariable(DistributedVariable): with ds_context.enter_or_assert_strategy(self._distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_sub_cross_replica( self, value, read_value=read_value) else: @@ -1167,6 +1184,7 @@ class SyncOnReadVariable(DistributedVariable): with ds_context.enter_or_assert_strategy(self._distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_add_cross_replica( self, value, read_value=read_value) else: @@ -1179,6 +1197,7 @@ class SyncOnReadVariable(DistributedVariable): with ds_context.enter_or_assert_strategy(self._distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_cross_replica( self, value, read_value=read_value) else: @@ -1243,7 +1262,8 @@ class SyncOnReadVariable(DistributedVariable): # Consider returning a tensor value here to make the return value of # _get_cross_replica consistent. return self._get_replica(0) - + if self._aggregation == vs.VariableAggregation.SUM: + values_util.mark_as_unsaveable() with ds_context.enter_or_assert_strategy(self._distribute_strategy): return self._distribute_strategy.reduce( reduce_util.ReduceOp.from_variable_aggregation(self._aggregation), @@ -1400,9 +1420,10 @@ class OnReadPolicy(VariablePolicy): def _get_cross_replica(self, var): if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: return var._get_replica(0) # pylint: disable=protected-access - + if self._aggregation == vs.VariableAggregation.SUM: + values_util.mark_as_unsaveable() with ds_context.enter_or_assert_strategy(var.distribute_strategy): - return var.distribute_strategy.reduce( + return var.distribute_strategy.reduce( reduce_util.ReduceOp.from_variable_aggregation(self._aggregation), var, axis=None) @@ -1421,6 +1442,7 @@ class OnReadPolicy(VariablePolicy): with ds_context.enter_or_assert_strategy(var.distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_sub_cross_replica( var, value, read_value=read_value) else: @@ -1434,6 +1456,7 @@ class OnReadPolicy(VariablePolicy): with ds_context.enter_or_assert_strategy(var.distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_add_cross_replica( var, value, read_value=read_value) else: @@ -1445,6 +1468,7 @@ class OnReadPolicy(VariablePolicy): with ds_context.enter_or_assert_strategy(var.distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_cross_replica(var, value, read_value=read_value) else: diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index 8a9f0acbd75..1f9bef137d5 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -56,6 +56,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.saved_model import save from tensorflow.python.saved_model import save_context from tensorflow.python.saved_model import save_options from tensorflow.python.training import saver as saver_lib @@ -825,6 +826,67 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): # pylint: enable=g-long-lambda + def testUnsaveable(self, distribution, synchronization, aggregation, mode): + if isinstance(distribution.extended, + parameter_server_strategy.ParameterServerStrategyExtended): + self.skipTest("n/a: not appliable to AggregatingVariable") + if (isinstance(distribution, + collective_all_reduce_strategy.CollectiveAllReduceStrategy) + and mode == "graph"): + self.skipTest("MWMS combinations tests do not work well in graph mode.") + with distribution.scope(): + v = variables_lib.Variable([1., 1.], + synchronization=synchronization, + aggregation=aggregation) + + with self.cached_session(): + self.evaluate(variables_lib.global_variables_initializer()) + + export_dir = self.get_temp_dir() + + def _assert_unsaveable(f): + # Ignore if it cannot be traced. Certain combinations are not supported or + # yet or not allowed. + try: + f = def_function.function(f).get_concrete_function() + except (NotImplementedError, ValueError): + return + with self.assertRaisesRegex(ValueError, "f_with_input_signature"): + save.save(v, export_dir, signatures=f) + + _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.]))) + _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.]))) + _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.]))) + _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0]))) + _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0]))) + _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0]))) + _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0]))) + _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0]))) + _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0]))) + _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0]))) + # Reading a ON_READ variable should be unsaveable if either: + # 1) CollectiveAllReduceStrategy, and aggregation is MEAN/SUM. + # 2) aggregation is SUM. + if (synchronization == variables_lib.VariableSynchronization.ON_READ and + (aggregation == variables_lib.VariableAggregation.SUM or + (isinstance(distribution.extended, + collective_all_reduce_strategy.CollectiveAllReduceExtended) + and aggregation == variables_lib.VariableAggregation.MEAN))): + _assert_unsaveable(v.read_value) + _assert_unsaveable(v.value) + _assert_unsaveable(lambda: ops.convert_to_tensor(v)) + else: + # Otherwise reading a variable should be saveable. + + @def_function.function + def f(): + v.read_value() + v.value() + return ops.convert_to_tensor(v) + + with self.cached_session(): + save.save(v, export_dir, signatures=f.get_concrete_function()) + @combinations.generate( combinations.combine( diff --git a/tensorflow/python/distribute/values_util.py b/tensorflow/python/distribute/values_util.py index 0071ee67b67..369e2435d9b 100644 --- a/tensorflow/python/distribute/values_util.py +++ b/tensorflow/python/distribute/values_util.py @@ -371,3 +371,23 @@ def is_saving_non_distributed(): options = save_context.get_save_options() return (options.experimental_variable_policy != save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES) + + +def mark_as_unsaveable(): + """Marks the function as unsaveable if not inside save context.""" + if ops.inside_function() and not save_context.in_save_context(): + ops.get_default_graph().mark_as_unsaveable(""" +ConcreteFunction that uses distributed variables in certain way cannot be saved. +If you're saving with + +tf.saved_model.save(..., signatures=f.get_concrete_function()) + +do + +@tf.function(input_signature=...) +def f_with_input_signature(): + ... + +tf.saved_model.save(..., signatures=f_with_input_signature)` + +instead.""") diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index b743d6f736c..845bc669f3a 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -103,7 +103,6 @@ cuda_py_test( "no_oss", # TODO(b/168051787): Enable. "no_pip", # TODO(b/168051787): Enable. ], - tfrt_enabled = True, deps = [ ":pywrap_tensor_test_util", ":test", @@ -179,7 +178,6 @@ cuda_py_test( size = "small", srcs = ["cancellation_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":cancellation", ":test", @@ -264,7 +262,6 @@ cuda_py_test( size = "small", srcs = ["context_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":context", ":test", @@ -288,7 +285,6 @@ cuda_py_test( name = "monitoring_test", srcs = ["monitoring_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":monitoring", ":test", @@ -361,7 +357,6 @@ cuda_py_test( name = "tensor_test", srcs = ["tensor_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":context", ":test", @@ -415,7 +410,6 @@ cuda_py_test( size = "small", srcs = ["core_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":context", ":core", @@ -451,7 +445,6 @@ cuda_py_test( size = "medium", srcs = ["function_defun_collection_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":backprop", ":def_function", @@ -567,7 +560,6 @@ cuda_py_test( name = "graph_only_ops_test", srcs = ["graph_only_ops_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "graph_only_ops", "//tensorflow/python:client_testlib", @@ -612,7 +604,6 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/ops/numpy_ops:numpy", "//tensorflow/python/saved_model:save_context", - "//tensorflow/python/saved_model:save_options", "//third_party/py/numpy", "@six_archive//:six", ], @@ -704,7 +695,6 @@ cuda_py_test( name = "benchmarks_test", srcs = ["benchmarks_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":backprop", ":benchmarks_test_base", @@ -724,7 +714,6 @@ cuda_py_test( name = "remote_benchmarks_test", srcs = ["remote_benchmarks_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":backprop", ":benchmarks_test_base", @@ -750,7 +739,6 @@ tf_py_test( name = "tape_test", srcs = ["tape_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":backprop", ":context", @@ -770,7 +758,6 @@ cuda_py_test( name = "ops_test", srcs = ["ops_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":context", ":execute", @@ -796,7 +783,6 @@ tf_py_test( name = "pywrap_tfe_test", srcs = ["pywrap_tfe_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":backprop", ":context", @@ -862,7 +848,6 @@ tf_py_test( size = "medium", srcs = ["lift_to_graph_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "lift_to_graph", "//tensorflow/python:framework_ops", @@ -1072,6 +1057,7 @@ cuda_py_test( shard_count = 8, tags = [ "no_oss", # This test launches local server + "notsan", # TODO(b/170783249) ], deps = [ "//tensorflow/python:array_ops", @@ -1119,6 +1105,7 @@ cuda_py_test( size = "small", srcs = ["device_placement_test.py"], python_version = "PY3", + shard_count = 5, deps = [ ":context", ":def_function", diff --git a/tensorflow/python/eager/benchmarks/resnet50/BUILD b/tensorflow/python/eager/benchmarks/resnet50/BUILD index ccec9f858a2..6c63658e3c7 100644 --- a/tensorflow/python/eager/benchmarks/resnet50/BUILD +++ b/tensorflow/python/eager/benchmarks/resnet50/BUILD @@ -46,7 +46,6 @@ cuda_py_test( "oss_serial", "v1only", ], - tfrt_enabled = True, deps = [ ":resnet50", ":resnet50_test_util", diff --git a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py index 52803be7bf9..573c8bc2e10 100644 --- a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py +++ b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py @@ -336,7 +336,6 @@ class ResNet50Benchmarks(tf.test.Benchmark): defun=False, execution_mode=context.ASYNC) - @test_util.disable_tfrt('Graph is not supported yet. b/156187905') def benchmark_eager_apply_with_defun(self): self._benchmark_eager_apply( 'eager_apply_with_defun', @@ -416,7 +415,6 @@ class ResNet50Benchmarks(tf.test.Benchmark): resnet50_test_util.device_and_data_format(), defun=False) - @test_util.disable_tfrt('Graph is not supported yet. b/156187905') def benchmark_eager_train_datasets_with_defun(self): def make_iterator(tensors): diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index bf6e43f5e65..37ab60918c2 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -923,19 +923,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): func = lambda: math_ops.reduce_logsumexp(x) self._run(func, 3000, execution_mode=execution_mode) - @test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.") def benchmark_tf_reduce_logsumexp_CPU(self): self._benchmark_tf_reduce_logsumexp() - @test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.") def benchmark_tf_reduce_logsumexp_CPU_async(self): self._benchmark_tf_reduce_logsumexp(execution_mode=context.ASYNC) - @test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.") def benchmark_tf_reduce_logsumexp_GPU(self): self._benchmark_tf_reduce_logsumexp(device=GPU) - @test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.") def benchmark_tf_reduce_logsumexp_GPU_async(self): self._benchmark_tf_reduce_logsumexp(device=GPU, execution_mode=context.ASYNC) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 8fe158cff93..3e32d05be64 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -414,7 +414,7 @@ class Context(object): raise ValueError( "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode) if execution_mode is None: - execution_mode = ASYNC if tfrt_utils.enabled() else SYNC + execution_mode = SYNC self._default_is_async = execution_mode == ASYNC self._lazy_remote_inputs_copy = None self._use_tfrt = tfrt_utils.enabled() @@ -748,7 +748,7 @@ class Context(object): self.ensure_initialized() pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message) - def check_collective_ops_peer_health(self, task): + def check_collective_ops_peer_health(self, task, timeout_in_ms): """Check collective peer health. This probes each task to see if they're still alive. Note that restarted @@ -758,6 +758,7 @@ class Context(object): Args: task: a task string, must be in the format of /job:xxx/replica:0/task:N. + timeout_in_ms: an integer, the timeout. If zero, there's no timeout. Raises: tf.errors.UnavailableError: when a peer is down. @@ -766,7 +767,8 @@ class Context(object): tf.errors.InvalidArgumentError: when the task string is invalid. """ self.ensure_initialized() - pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task) + pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task, + timeout_in_ms) @property def _handle(self): @@ -948,7 +950,12 @@ class Context(object): if self._log_device_placement is not None: config.log_device_placement = self._log_device_placement - config.experimental.enable_mlir_bridge = pywrap_tfe.TF_IsMlirBridgeEnabled() + is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled() + config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled + if (is_mlir_bridge_enabled == + config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED): + config.experimental.enable_mlir_bridge = True + if self._enable_mlir_graph_optimization is not None: config.experimental.enable_mlir_graph_optimization = ( self._enable_mlir_graph_optimization) diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 1e56b577d96..42af94c6cb1 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -153,7 +153,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): init_fn() self.assertEqual(state[0].numpy(), 2.0) - @test_util.disable_tfrt('Error in native condition op.') def testVariableInitializerNotConstant(self): state = [] @@ -385,7 +384,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): 'defined in another function or code block'): f(array_ops.zeros(shape=(8, 42, 3))) - @test_util.disable_tfrt('b/169375363: error code support') def testRuntimeErrorNotSticky(self): @def_function.function @@ -591,7 +589,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): experimental_variable_policy=save_options.VariablePolicy.NONE)): func_d = func.get_concrete_function(constant_op.constant(2.)) - self.assertIs(func_a, func_c) + self.assertIsNot(func_a, func_c) self.assertIsNot(func_a, func_d) def testInitializationInNestedCall(self): diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 60dd3f17024..9d9cf0b50c3 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -67,7 +67,6 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.profiler import trace from tensorflow.python.saved_model import save_context -from tensorflow.python.saved_model import save_options from tensorflow.python.util import compat from tensorflow.python.util import function_utils from tensorflow.python.util import lazy_loader @@ -1419,13 +1418,6 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions): num_output_tangents) -# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are -# unfortunately too slow to use here. -_POSSIBLE_GRADIENT_TYPES_NONE = 0 -_POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1 -_POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2 - - class _ForwardBackwardCall(object): """Holds the state of a function call between execution and recording.""" @@ -1919,9 +1911,8 @@ class ConcreteFunction(object): "on invocation of %s, the %d-th input (%s) was not a " "Tensor." % (self._func_graph.name, i, str(arg))) args = tensor_inputs + captured_inputs - possible_gradient_type = ( - pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(args)) - if (possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE + possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args) + if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE and executing_eagerly): # No tape is watching; skip to running the function. return self._build_call_outputs(self._inference_function.call( @@ -2081,7 +2072,7 @@ class ConcreteFunction(object): Args: args: A flat list of Tensors with all of the inputs to the forward function (including user-specified and captured inputs). - possible_gradient_type: One of _POSSIBLE_GRADIENT_TYPES_*. + possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*. executing_eagerly: Boolean, the value of context.executing_eagerly(). Returns: @@ -2099,7 +2090,8 @@ class ConcreteFunction(object): # Allows re-use of forward and backward function pairs depending on the # tapes and forward accumulators watching its inputs. cache_key = (need_gradients_for_jvps, input_tangents.indices) - if possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_FIRST_ORDER: + if (possible_gradient_type + == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER): if input_tangents.indices or executing_eagerly: # There is a single non-persistent tape active, so the user can only # request first-order gradients from a tape. We can spend less time @@ -2130,7 +2122,8 @@ class ConcreteFunction(object): return _ForwardBackwardCall( self._delayed_rewrite_functions, args, input_tangents.tangents, tape_watching=True) - elif possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER: + elif (possible_gradient_type + == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER): # Either there's a persistent tape watching, or there are multiple nested # tapes. Either way, the user may request higher-order gradients. We'll # spend a bit more time and make sure higher-order gradients are correct. @@ -2145,7 +2138,7 @@ class ConcreteFunction(object): self._higher_order_tape_functions[cache_key] = functions return _ForwardBackwardCall(functions, args, input_tangents.tangents, tape_watching=True) - # else possible_gradient_type == _POSSIBLE_GRADIENT_TYPES_NONE, meaning no + # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no # tape is recording. return _ForwardBackwardCall( self._delayed_rewrite_functions, args, input_tangents.tangents, @@ -3177,10 +3170,7 @@ class Function(object): variable_policy = ( save_context.get_save_options().experimental_variable_policy) else: - # With EXPAND_DISTRIBUTED_VARIABLES the variables have the same behavior - # in and out of saving. We use EXPAND_DISTRIBUTED_VARIABLES so that if the - # user saves with it, there's no need to retrace the functions. - variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES + variable_policy = None return (parent_graph, device_functions, colocation_stack, in_cross_replica_context, variable_policy, xla_context_id) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 7055ddda437..7ebcf77ff61 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -4297,6 +4297,92 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(TypeError, 'missing required arguments: y'): foo.add(2) # pylint: disable=no-value-for-parameter + def testShapeInferencePropagateConstNestedStack(self): + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), + tensor_spec.TensorSpec((), dtype=dtypes.int32), + ]) + def f(x, s): + old_shape = array_ops.shape(x) + new_shape = array_ops.stack([old_shape[0], s], axis=0) + y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) + return y + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) + ]) + def g(x): + y = f(x, s=5) + assert y.shape.as_list() == [3, 5], y.shape.as_list() + return y + + self.assertAllEqual( + g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5])) + + def testShapeInferencePropagateConstNestedUnstackStack(self): + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), + tensor_spec.TensorSpec((), dtype=dtypes.int32), + ]) + def f(x, s): + s0, _ = array_ops.unstack(array_ops.shape(x), axis=0) + new_shape = array_ops.stack([s0, s], axis=0) + y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) + return y + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) + ]) + def g(x): + y = f(x, s=5) + assert y.shape.as_list() == [3, 5], y.shape.as_list() + return y + + self.assertAllEqual( + g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5])) + + def testShapeInferencePropagateConstNestedConcat(self): + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_spec.TensorSpec((), dtype=dtypes.int32), + ]) + def f(d1, d2, d3): + new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) + y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) + return y + + @def_function.function() + def g(): + y = f(1, 2, 3) + assert y.shape.as_list() == [1, 2, 3], y.shape.as_list() + return y + + self.assertAllEqual(g(), array_ops.ones([1, 2, 3])) + + def testShapeInferencePropagateConstDoubleNested(self): + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_spec.TensorSpec((), dtype=dtypes.int32), + tensor_spec.TensorSpec((), dtype=dtypes.int32), + ]) + def f(d1, d2, d3): + new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) + y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) + return y + + @def_function.function() + def g(): + y = def_function.function(f)(1, 2, 3) + assert y.shape.as_list() == [1, 2, 3], y.shape.as_list() + return y + + self.assertAllEqual(g(), array_ops.ones([1, 2, 3])) + @test_util.run_v2_only def testControlDependencyAfterInline(self): v = variables.Variable(0.) @@ -4319,6 +4405,25 @@ class FunctionTest(test.TestCase, parameterized.TestCase): for _ in range(30): f() + @test_util.run_v2_only + def testReadInFuncWriteOutside(self): + # Run many times since we are testing for a potential race condition. + for _ in range(30): + # pylint: disable=cell-var-from-loop + v = variables.Variable(1.) + + @def_function.function + def add_one(): + return v + 1. + + @def_function.function + def get_v_plus_one(): + v_plus_one = add_one() + v.assign_add(2.0) + return v_plus_one + + self.assertAllEqual(get_v_plus_one(), 2.0) + class MultiDeviceTest(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index 0c8bbe76c98..494abdbf269 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -394,7 +394,6 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): @parameterized.named_parameters( ('Tensor', lambda: constant_op.constant(1.3+1j)), ('Variable', lambda: resource_variable_ops.ResourceVariable(1.3+1j))) - @test_util.disable_tfrt('cannot create complex tensor in TFRT.') def testCastToPrimitiveTypesFrom(self, value_fn): x = value_fn() self.assertIsInstance(int(x), int) @@ -482,8 +481,8 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertIs(weak_x(), None) self.assertIs(weak_y(), None) - @test_util.disable_tfrt('TFE_ContextGetExecutorForThread not implemented ' - 'b/156188669') + @test_util.disable_tfrt( + 'b/153697193: tfrt cannot decode python stacktrace yet') def testAsyncExceptionStackTrace(self): config.set_synchronous_execution(False) diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 128fb09d114..4bc327adc33 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -3606,16 +3606,18 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) { } TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status); - tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace()); auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] { ReturnStatus(status); ReturnOp(ctx, op); }); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { return nullptr; } + tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace()); + const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef(); if (op_def == nullptr) return nullptr; diff --git a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py index aa120973b00..6d67b44a8ea 100644 --- a/tensorflow/python/eager/remote.py +++ b/tensorflow/python/eager/remote.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import copy +import socket from absl import logging @@ -165,9 +166,12 @@ def connect_to_cluster(cluster_spec_or_resolver, local_port = pywrap_tfe.TF_PickUnusedPortOrDie() job_def = cluster_def.job.add() job_def.name = job_name - # TODO(fishx): Update this to make sure remote worker has valid ip address - # to connect with local. - job_def.tasks[0] = "localhost:{}".format(local_port) + + ipstr = _get_local_ip_address(local_port) + if ipstr: + job_def.tasks[0] = "{}:{}".format(ipstr, local_port) + else: + job_def.tasks[0] = "localhost:{}".format(local_port) server_def = ServerDef( cluster=cluster_def, @@ -221,3 +225,29 @@ def connect_to_cluster(cluster_spec_or_resolver, def _strip_prefix(s, prefix): return s[len(prefix):] if s.startswith(prefix) else s + + +def _get_local_ip_address(port): + """Returns the first local ip address. + + Args: + port: the port used to lookup ip addresses using the socket library. + + Returns: + a string representing the ip address. If it is an IPv6 address, it will be + wrapped by a pair of brackets. Or None if a local ip address cannot be + found. + """ + hostname = socket.gethostname() + addrinfo = socket.getaddrinfo(hostname, port) + # Use the first ip address. + # See the documentation of socket.getaddrinfo here: + # https://docs.python.org/3/library/socket.html#socket.getaddrinfo. + if not addrinfo or not addrinfo[0][4]: + return None + else: + ipstr = addrinfo[0][4][0] + if addrinfo[0][0] == socket.AddressFamily.AF_INET6: + return "[%s]" % ipstr + else: + return ipstr diff --git a/tensorflow/python/eager/remote_cluster_test.py b/tensorflow/python/eager/remote_cluster_test.py index 84dbb11361a..e533ab8577d 100644 --- a/tensorflow/python/eager/remote_cluster_test.py +++ b/tensorflow/python/eager/remote_cluster_test.py @@ -320,6 +320,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): t.start() for _ in range(num_calls): + @def_function.function def worker_fn(i): return math_ops.matmul(i, i) @@ -389,10 +390,10 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): t1_results = [None] * num_calls t2_results = [None] * num_calls threads = [] - threads.append(threading.Thread(target=thread_fn, - args=(self.device_t1, t1_results))) - threads.append(threading.Thread(target=thread_fn, - args=(self.device_t2, t2_results))) + threads.append( + threading.Thread(target=thread_fn, args=(self.device_t1, t1_results))) + threads.append( + threading.Thread(target=thread_fn, args=(self.device_t2, t2_results))) threads.append(threading.Thread(target=update_server_def_fn)) for t in threads: t.start() @@ -535,6 +536,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): with ops.device(self.device_t2): add = mul + i return add - i + worker_fn.get_concrete_function(x1) num_calls = 10 @@ -551,13 +553,13 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): with self._coord.stop_on_exception(): for i in range(num_calls): context.update_server_def( - server_def=(self.server_def_s1_s2_s3 - if i % 2 == 0 else self.server_def_s1_s2)) + server_def=(self.server_def_s1_s2_s3 if i % + 2 == 0 else self.server_def_s1_s2)) results = [None] * num_calls threads = [] - threads.append(threading.Thread(target=thread_fn, - args=(self.device_t1, results))) + threads.append( + threading.Thread(target=thread_fn, args=(self.device_t1, results))) threads.append(threading.Thread(target=update_server_def_fn)) for t in threads: t.start() @@ -630,9 +632,8 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0")) self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1")) - with self.assertRaisesRegex( - errors.InvalidArgumentError, - "Client for target /job:remote_device/replica:0/task:10 not found."): + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Unable to find worker interface"): context.check_alive("/job:remote_device/replica:0/task:10") diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py index ba6a67910b1..3f6ab98605b 100644 --- a/tensorflow/python/framework/auto_control_deps.py +++ b/tensorflow/python/framework/auto_control_deps.py @@ -373,12 +373,14 @@ class AutomaticControlDependencies(object): if control_flow_util.IsInWhileLoop(op): continue control_inputs = set() - # Ensure stateful ops run. Note that this includes read only ops, although - # they don't have direct side effect, they are affected by ops that writes - # the same resource and may be inputs to side-effect ops like tf.print. If - # the function gets inlined, they must execute before ops that depend on - # the function call. - if op_def_registry.get(op.type) is None or op_is_stateful(op): + # Ensure stateful ops run. + # Read-only ops are added to control outputs if the read value is + # consumed. This covers the case when the read value is returned from + # the function since that goes through a tf.identity in mark_as_return. + if (op_def_registry.get(op.type) is None or + (op_is_stateful(op) and + (op.type not in utils.RESOURCE_READ_OPS or + any(output.consumers() for output in op.outputs)))): ops_which_must_run.add(op) # Make a note of all opened manager_ids. if op.type == "NoOp": diff --git a/tensorflow/python/framework/auto_control_deps_test.py b/tensorflow/python/framework/auto_control_deps_test.py index 8b549263229..a7b238c31b8 100644 --- a/tensorflow/python/framework/auto_control_deps_test.py +++ b/tensorflow/python/framework/auto_control_deps_test.py @@ -107,8 +107,10 @@ class AutomaticControlDependenciesTest(test.TestCase): v = resource_variable_ops.ResourceVariable(1.0) self.evaluate(variables.global_variables_initializer()) with acd.AutomaticControlDependencies() as c: - read_op = gen_resource_variable_ops.read_variable_op( - v.handle, v.dtype).op + read_op = gen_resource_variable_ops.read_variable_op(v.handle, + v.dtype).op + # Read ops get added to control outputs only if they have consumers. + c.mark_as_return(read_op.outputs[0]) self.assertIn(read_op, c.ops_which_must_run) def testVariableMultipleReadsAndWrites(self): @@ -133,6 +135,11 @@ class AutomaticControlDependenciesTest(test.TestCase): v.handle, v + 1) assign_op4 = gen_resource_variable_ops.assign_variable_op( v.handle, v + 1) + # Read ops get added to control outputs only if they have consumers. + c.mark_as_return(read_op1.outputs[0]) + c.mark_as_return(read_op2.outputs[0]) + c.mark_as_return(read_op3.outputs[0]) + c.mark_as_return(read_op4.outputs[0]) # Verify the control edges. self.assertIn(read_op1, assign_op1.control_inputs) diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index 7aed6fb2b70..2691665ffce 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -48,12 +48,11 @@ def enable_tensor_float_32_execution(enabled): This reduced precision should not impact convergence of deep learning models in practice. - TensorFloat-32 is enabled by default in the nightly versions of TensorFlow. We - expect it will remain enabled by default in the first stable version that - TensorFloat-32 is available, which is TensorFlow 2.4, as it increases - performance and does not reduce model quality in practice. If you want to use - the full float32 precision, you can disable TensorFloat-32 execution with this - function. For example: + TensorFloat-32 is enabled by default. TensorFloat-32 is only supported on + Ampere GPUs, so all other hardware will use the full float32 precision + regardless of whether TensorFloat-32 is enabled or not. If you want to use the + full float32 precision on Ampere, you can disable TensorFloat-32 execution + with this function. For example: ```python x = tf.fill((2, 2), 1.0001) @@ -65,28 +64,26 @@ def enable_tensor_float_32_execution(enabled): print(tf.linalg.matmul(x, y)) # [[2.0002, 2.0002], [2.0002, 2.0002]] ``` - There is [an RFC](https://github.com/tensorflow/community/pull/287) proposing - that TensorFloat-32 remain enabled by default in stable versions of - TensorFlow. We expect the RFC to be accepted, but if it isn't, TensorFloat-32 - will be disabled by default in TensorFlow 2.4. - To check whether TensorFloat-32 execution is currently enabled, use `tf.config.experimental.tensor_float_32_execution_enabled`. - Enabling TensorFloat-32 causes float32 inputs of supported ops, such as - `tf.linalg.matmul`, to be rounded from 23 bits of precision to 10 bits of + If TensorFloat-32 is enabled, float32 inputs of supported ops, such as + `tf.linalg.matmul`, will be rounded from 23 bits of precision to 10 bits of precision in most cases. This allows the ops to execute much faster by utilizing the GPU's tensor cores. TensorFloat-32 has the same dynamic range as float32, meaning it is no more likely to underflow or overflow than float32. - Ops still use float32 accumulation when TensorFloat-32 is enabled. Enabling - TensorFloat-32 only affects Ampere GPUs and subsequent GPUs that support - TensorFloat-32. + Ops still use float32 accumulation when TensorFloat-32 is enabled. Enabling or + disabling TensorFloat-32 only affects Ampere GPUs and subsequent GPUs that + support TensorFloat-32. Note TensorFloat-32 is not always used in supported ops, as only inputs of certain shapes are supported. Support for more input shapes and more ops may be added in the future. As a result, precision of float32 ops may decrease in minor versions of TensorFlow. + TensorFloat-32 is also used for some complex64 ops. Currently, TensorFloat-32 + is used in fewer cases for complex64 as it is for float32. + Args: enabled: Bool indicating whether to enable TensorFloat-32 execution. """ diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index a20af802824..7dd26425037 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -214,14 +214,23 @@ class ConfigTest(test.TestCase, parameterized.TestCase): def testEnableMlirBridge(self): # Default value of enable_mlir_bridge is false. self.assertFalse(context.context().config.experimental.enable_mlir_bridge) + self.assertEqual( + context.context().config.experimental.mlir_bridge_rollout, + config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) # Tests enabling mlir bridge. config.enable_mlir_bridge() self.assertTrue(context.context().config.experimental.enable_mlir_bridge) + self.assertEqual( + context.context().config.experimental.mlir_bridge_rollout, + config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED) # Tests disabling mlir bridge. config.disable_mlir_bridge() self.assertFalse(context.context().config.experimental.enable_mlir_bridge) + self.assertEqual( + context.context().config.experimental.mlir_bridge_rollout, + config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_DISABLED) @reset_eager def testEnableMlirGraphOptimization(self): diff --git a/tensorflow/python/framework/experimental/BUILD b/tensorflow/python/framework/experimental/BUILD index 3d404a411c9..4f563d18a59 100644 --- a/tensorflow/python/framework/experimental/BUILD +++ b/tensorflow/python/framework/experimental/BUILD @@ -139,7 +139,6 @@ cuda_py_test( "no_pip", "no_windows", # b/168218876 ], - tfrt_enabled = True, deps = [ ":_unified_api", ":context_stack", diff --git a/tensorflow/python/framework/experimental/math_ops.cc b/tensorflow/python/framework/experimental/math_ops.cc index 8a6d8525092..5e9522f1999 100644 --- a/tensorflow/python/framework/experimental/math_ops.cc +++ b/tensorflow/python/framework/experimental/math_ops.cc @@ -53,5 +53,27 @@ PYBIND11_MODULE(_math_ops, m) { /*transpose_a=*/false, /*transpose_b=*/false)); return outputs[0]; }); + m.def("neg", + [](AbstractContext* ctx, AbstractTensorHandle* a, const char* name) { + int num_outputs = 1; + std::vector outputs(1); + if (!name) { + name = "Neg"; + } + MaybeRaiseRegisteredFromStatus( + ops::Neg(ctx, {a}, absl::MakeSpan(outputs), name)); + return outputs[0]; + }); + m.def("sub", [](AbstractContext* ctx, AbstractTensorHandle* a, + AbstractTensorHandle* b, const char* name) { + int num_outputs = 1; + std::vector outputs(1); + if (!name) { + name = "Sub"; + } + MaybeRaiseRegisteredFromStatus( + ops::Sub(ctx, {a, b}, absl::MakeSpan(outputs), name)); + return outputs[0]; + }); } } // namespace tensorflow diff --git a/tensorflow/python/framework/experimental/math_ops.py b/tensorflow/python/framework/experimental/math_ops.py index 7b3a171da1f..879cddfa036 100644 --- a/tensorflow/python/framework/experimental/math_ops.py +++ b/tensorflow/python/framework/experimental/math_ops.py @@ -30,3 +30,13 @@ def add(a, b, name=None): def mat_mul(a, b, name=None): ctx = context.get_default() return _math_ops.mat_mul(ctx, a, b, name) + + +def neg(a, name=None): + ctx = context.get_default() + return _math_ops.neg(ctx, a, name) + + +def sub(a, b, name=None): + ctx = context.get_default() + return _math_ops.sub(ctx, a, b, name) diff --git a/tensorflow/python/framework/experimental/tape.cc b/tensorflow/python/framework/experimental/tape.cc index 85c943dddf9..a6975c085ac 100644 --- a/tensorflow/python/framework/experimental/tape.cc +++ b/tensorflow/python/framework/experimental/tape.cc @@ -36,6 +36,8 @@ Status RegisterGradients(GradientRegistry* registry) { TF_RETURN_IF_ERROR( registry->Register("SparseSoftmaxCrossEntropyWithLogits", SparseSoftmaxCrossEntropyWithLogitsRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer)); return Status::OK(); } diff --git a/tensorflow/python/framework/experimental/unified_api_test.py b/tensorflow/python/framework/experimental/unified_api_test.py index 3b476255c44..8edb3f51f7a 100644 --- a/tensorflow/python/framework/experimental/unified_api_test.py +++ b/tensorflow/python/framework/experimental/unified_api_test.py @@ -163,6 +163,99 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase): eager_output = model(negative) self.assertAllEqual(eager_output.numpy(), [0.]) + @parameterized.named_parameters([ + ("Graph", False), + ("Mlir", True), + ]) + def testNeg(self, use_mlir): + if use_mlir: + SetTracingImplementation("mlir") + + def model(a): + return unified_math_ops.neg(a) + + with context_lib.set_default(get_immediate_execution_context()): + a = TensorCastHelper(constant_op.constant([2.])) + + func_output = def_function.function(model)(a) + self.assertAllEqual(func_output.numpy(), [-2.]) + + eager_output = model(a) + self.assertAllEqual(eager_output.numpy(), [-2.]) + + @parameterized.named_parameters([ + ("Graph", False), + ("Mlir", True), + ]) + def testNegGrad(self, use_mlir): + if use_mlir: + SetTracingImplementation("mlir") + + def model(a): + with tape_lib.GradientTape() as tape: + tape.watch(a) + result = unified_math_ops.neg(a) + grads = tape.gradient(result, a) + return grads + + with context_lib.set_default(get_immediate_execution_context()): + a = TensorCastHelper(constant_op.constant([2.])) + + func_outputs = def_function.function(model)(a) + self.assertAllEqual(func_outputs.numpy(), [-1.0]) + + eager_outputs = model(a) + self.assertAllEqual(eager_outputs.numpy(), [-1.0]) + + @parameterized.named_parameters([ + ("Graph", False), + ("Mlir", True), + ]) + def testSub(self, use_mlir): + if use_mlir: + SetTracingImplementation("mlir") + + def model(a, b): + return unified_math_ops.sub(a, b) + + with context_lib.set_default(get_immediate_execution_context()): + a = TensorCastHelper(constant_op.constant([1., 2.])) + b = TensorCastHelper(constant_op.constant([3., 4.])) + + func_output = def_function.function(model)(a, b) + self.assertAllEqual(func_output.numpy(), [-2., -2.]) + + eager_output = model(a, b) + self.assertAllEqual(eager_output.numpy(), [-2., -2.]) + + @parameterized.named_parameters([ + ("Graph", False), + ("Mlir", True), + ]) + def testSubGrad(self, use_mlir): + if use_mlir: + SetTracingImplementation("mlir") + + def model(a, b): + with tape_lib.GradientTape() as tape: + tape.watch(a) + tape.watch(b) + result = unified_math_ops.sub(a, b) + grads = tape.gradient(result, [a, b]) + return grads + + with context_lib.set_default(get_immediate_execution_context()): + a = TensorCastHelper(constant_op.constant([1., 2.])) + b = TensorCastHelper(constant_op.constant([3., 4.])) + + func_outputs = def_function.function(model)(a, b) + self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0]) + self.assertAllEqual(func_outputs[1].numpy(), [-1.0, -1.0]) + + eager_outputs = model(a, b) + self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0]) + self.assertAllEqual(eager_outputs[1].numpy(), [-1.0, -1.0]) + class UnifiedTapeBenchmark(test.Benchmark): diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 90cd0f62986..ea376077ab7 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1487,6 +1487,8 @@ class FunctionCaptureByValueTest(test.TestCase): self.assertAllEqual(y, [[12.0]]) +@test_util.run_all_without_tensor_float_32( + "Calls matmul in custom LSTM function") class UnrollLSTMTest(test.TestCase): BATCH_SIZE = 16 LSTM_DIMS = 32 @@ -1593,7 +1595,6 @@ class UnrollLSTMTest(test.TestCase): self.assertAllClose(mv0, mv2, rtol=1e-4) self.assertAllClose(mv0, mv3, rtol=1e-4) - @test_util.run_without_tensor_float_32("Calls matmul in custom LSTM function") def testUnrollLSTMGrad(self): # Run one step of the unrolled lstm graph. def RunForwardBackward(mode, cfg=None): diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py index 743d409a1ce..cbcb2652c70 100644 --- a/tensorflow/python/framework/load_library.py +++ b/tensorflow/python/framework/load_library.py @@ -159,7 +159,7 @@ def load_library(library_location): library_location) def load_pluggable_device_library(library_location): - """Loads a Tensorflow PluggableDevice plugin + """Loads a Tensorflow PluggableDevice plugin. "library_location" can be a path to a specific shared object, or a folder. If it is a folder, all shared objects will be loaded. when the library is loaded, devices/kernels registered in the library via StreamExecutor C API @@ -176,9 +176,9 @@ def load_pluggable_device_library(library_location): OSError: When the file to be loaded is not found. RuntimeError: when unable to load the library. """ - if file_io.file_exists(library_location): - if file_io.is_directory(library_location): - directory_contents = file_io.list_directory(library_location) + if os.path.exists(library_location): + if os.path.isdir(library_location): + directory_contents = os.listdir(library_location) pluggable_device_libraries = [ os.path.join(library_location, f) for f in directory_contents @@ -188,12 +188,35 @@ def load_pluggable_device_library(library_location): for lib in pluggable_device_libraries: py_tf.TF_LoadPluggableDeviceLibrary(lib) - # Reinitialized physical devices list after plugin registration + # Reinitialized physical devices list after plugin registration. context.context().reinitialize_physical_devices() else: raise OSError( errno.ENOENT, - 'The file or folder to load pluggable device libraries from does not exist.', + 'The file or folder to load pluggable device libraries from does\ + not exist.', library_location) +@tf_export('experimental.register_filesystem_plugin') +def register_filesystem_plugin(plugin_location): + """Loads a TensorFlow FileSystem plugin. + + Args: + plugin_location: Path to the plugin. Relative or absolute filesystem plugin + path to a dynamic library file. + + Returns: + None + + Raises: + OSError: When the file to be loaded is not found. + RuntimeError: when unable to load the library. + """ + if os.path.exists(plugin_location): + py_tf.TF_RegisterFilesystemPlugin(plugin_location) + + else: + raise OSError(errno.ENOENT, + 'The file to load file system plugin from does not exist.', + plugin_location) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index ccc1daf721c..47561b2c115 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -4332,12 +4332,16 @@ class Graph(object): def _colocate_with_for_gradient(self, op, gradient_uid, ignore_existing=False): with self.colocate_with(op, ignore_existing): - if gradient_uid is not None and self._control_flow_context is not None: - self._control_flow_context.EnterGradientColocation(op, gradient_uid) - try: + if gradient_uid is not None: + ctx = _get_enclosing_context(self) + if ctx is not None: + ctx.EnterGradientColocation(op, gradient_uid) + try: + yield + finally: + ctx.ExitGradientColocation(op, gradient_uid) + else: yield - finally: - self._control_flow_context.ExitGradientColocation(op, gradient_uid) else: yield @@ -6955,3 +6959,15 @@ def set_int_list_attr(op, attr_name, ints): """TF internal method used to set a list(int) attribute in the node_def.""" ints_list = attr_value_pb2.AttrValue.ListValue(i=ints) op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list)) # pylint:disable=protected-access + + +def _get_enclosing_context(graph): + # pylint: disable=protected-access + if graph is None: + return None + + if graph._control_flow_context is not None: + return graph._control_flow_context + + if graph.building_function and hasattr(graph, "outer_graph"): + return _get_enclosing_context(graph.outer_graph) diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index e3fe4c07a57..04b6d90a838 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -326,7 +326,6 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(z, [False, False, False, True]) - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testBitwiseAndErrors(self): x_int = constant_op.constant(0) @@ -368,7 +367,6 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(z, [False, True, True, True]) - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testBitwiseOrErrors(self): x_int = constant_op.constant(0) @@ -410,7 +408,6 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(z, [False, True, True, False]) - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testBitwiseXorErrors(self): x_int = constant_op.constant(0) @@ -450,7 +447,6 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(y, [True, False]) - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testBitwiseNotErrors(self): if context.executing_eagerly(): # :( diff --git a/tensorflow/python/framework/python_api_dispatcher.cc b/tensorflow/python/framework/python_api_dispatcher.cc new file mode 100644 index 00000000000..57a6a9ce94b --- /dev/null +++ b/tensorflow/python/framework/python_api_dispatcher.cc @@ -0,0 +1,220 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/python/framework/python_api_dispatcher.h" + +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_join.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" +#include "tensorflow/python/util/util.h" + +namespace tensorflow { + +using ParamInfo = PythonAPIDispatcher::ParamInfo; + +// List of python types to check for dispatch. In most cases, this vector +// will have size zero or one; and sizes greater than 3 should be rare. +using TypeList = absl::InlinedVector; + +namespace { + +// Returns the __tf__dispatch__ attribute of `obj`. +Safe_PyObjectPtr GetAttr_TFDispatch(PyObject* obj) { +#if PY_MAJOR_VERSION < 3 + // Python 2.x: + static PyObject* attr = PyString_InternFromString("__tf_dispatch__"); +#else + // Python 3.x: + static PyObject* attr = PyUnicode_InternFromString("__tf_dispatch__"); +#endif + return Safe_PyObjectPtr(PyObject_GetAttr(obj, attr)); +} + +// Searches `params` for dispatchable types, and returns a vector of borrowed +// references to those types. Removes consecutive duplicates (i.e., if a +// dispatchable parameter has the same type as the previously encountered +// dispatcahble parameter, then it's type is not added again), so the result +// will usually have a length of zero or one; but in the general case, it may be +// longer, and may contain (nonconsecutive) duplicates. +// +// Assumes that `params` is a tuple, and that all parameter indices in +// `dispatch_params` and `dispatch_list_params` are valid. +TypeList FindDispatchTypes(PyObject* params, + const std::vector& dispatchable_params) { + TypeList dispatch_types; + for (const auto& param : dispatchable_params) { + DCHECK_GE(param.index, 0); + DCHECK_LT(param.index, PyTuple_GET_SIZE(params)); + PyObject* value = PyTuple_GET_ITEM(params, param.index); + if (param.is_list) { + DCHECK(PyList_Check(value)); + Py_ssize_t num_items = PyList_Size(value); + for (Py_ssize_t i = 0; i < num_items; ++i) { + PyObject* item = PyList_GET_ITEM(value, i); + // TODO(b/164980194) Consider changing IsDispatchable to not use a + // cache. This may impact efficiency (needs to be measured), but would + // allow us to support monkey-patching classes to be dispatchable. + if (swig::IsDispatchable(item)) { + if (dispatch_types.empty() || + value->ob_type != dispatch_types.back()) { + dispatch_types.push_back(item->ob_type); + } + } + } + } else { + if (swig::IsDispatchable(value)) { + if (dispatch_types.empty() || value->ob_type != dispatch_types.back()) { + dispatch_types.push_back(value->ob_type); + } + } + } + } + + return dispatch_types; +} + +// Removes duplicates from `dispatch_types`, and moves any subtypes to +// before their supertypes. Note: this method is only called when +// `dispatch_types.size() > 1`. +void SortDispatchTypes(TypeList& dispatch_types) { + // Remove duplicates. Note: this is O(n^2) in the number of dispatchable + // types, but we expect this number to be very small in almost every case + // (usually zero, sometimes one, and rarely larger than two). + for (int i = 0; i < dispatch_types.size() - 1; ++i) { + if (dispatch_types[i] == nullptr) continue; + for (int j = i + 1; j < dispatch_types.size(); ++j) { + if (dispatch_types[i] == dispatch_types[j]) { + dispatch_types[j] = nullptr; // mark duplicate + } + } + } + dispatch_types.erase( + std::remove_if(dispatch_types.begin(), dispatch_types.end(), + [](PyTypeObject* t) { return t == nullptr; }), + dispatch_types.end()); + + // Move subclasses before superclasses. As above, this is O(n^2), but we + // expect n to be small. + TypeList sorted; + TypeList subtypes; + for (int i = 0; i < dispatch_types.size(); ++i) { + if (dispatch_types[i] == nullptr) continue; + subtypes.clear(); + for (int j = i + 1; j < dispatch_types.size(); ++j) { + if (dispatch_types[j] == nullptr) continue; + if (PyType_IsSubtype(dispatch_types[j], dispatch_types[i])) { + subtypes.push_back(dispatch_types[j]); + dispatch_types[j] = nullptr; // mark as already added. + } + } + if (!subtypes.empty()) { + std::sort(subtypes.begin(), subtypes.end(), PyType_IsSubtype); + sorted.insert(sorted.end(), subtypes.begin(), subtypes.end()); + } + sorted.push_back(dispatch_types[i]); + } + DCHECK_EQ(dispatch_types.size(), sorted.size()); + dispatch_types.swap(sorted); +} + +} // namespace + +PythonAPIDispatcher::PythonAPIDispatcher(const std::string& api_name, + PyObject* api_func, int num_params, + bool right_to_left) + : api_name_(PyUnicode_FromStringAndSize(api_name.c_str(), api_name.size())), + api_func_(api_func), + num_params_(num_params), + right_to_left_(right_to_left) { + Py_INCREF(api_func); +} + +bool PythonAPIDispatcher::Initialize( + std::vector dispatchable_params) { + dispatchable_params_.swap(dispatchable_params); + std::sort(dispatchable_params_.begin(), dispatchable_params_.end(), + [](const ParamInfo& a, const ParamInfo& b) -> bool { + return a.index < b.index; + }); + if (right_to_left_) { + std::reverse(dispatchable_params_.begin(), dispatchable_params_.end()); + } + + for (const auto& p : dispatchable_params_) { + if (p.index < 0 || p.index >= num_params_) { + PyErr_SetString( + PyExc_ValueError, + absl::StrCat("PythonAPIDispatcher: dispatchable parameter index out ", + "of range: ", p.index, " not in [0, ", num_params_, ")") + .c_str()); + return false; + } + } + return true; +} + +PyObject* PythonAPIDispatcher::Dispatch(PyObject* params) const { + DCHECK(PyTuple_Check(params)); + + // TODO(b/164980194) Consider removing this check, if the caller is also + // checking/guaranteeing it (once dispatch has been integrated w/ the Python + // API handlers). + if (num_params_ != PyTuple_Size(params)) { +#if PY_MAJOR_VERSION < 3 + // Python 2.x: + Safe_PyObjectPtr api_name_str(PyUnicode_AsUTF8String(api_name_.get())); + if (!api_name_str) return nullptr; + const char* api_name = PyString_AsString(api_name_str.get()); +#else + // Python 3.x: + const char* api_name = PyUnicode_AsUTF8AndSize(api_name_.get(), nullptr); +#endif + PyErr_SetString( + PyExc_TypeError, + absl::StrCat(api_name ? api_name : "unknown PythonAPIDispatcher", + " expected ", num_params_, " parameters, but got ", + PyTuple_Size(params)) + .c_str()); + return nullptr; + } + + TypeList dispatch_types = FindDispatchTypes(params, dispatchable_params_); + + if (dispatch_types.empty()) { + return Py_NotImplemented; + } + + if (dispatch_types.size() > 1) { + SortDispatchTypes(dispatch_types); + } + + for (PyTypeObject* dispatch_type : dispatch_types) { + Safe_PyObjectPtr dispatcher = + GetAttr_TFDispatch(reinterpret_cast(dispatch_type)); + if (!dispatcher) return nullptr; + PyObject* result = PyObject_CallFunctionObjArgs( + dispatcher.get(), api_name_.get(), api_func_.get(), params, nullptr); + if (result != Py_NotImplemented) { + return result; + } + } + + return Py_NotImplemented; +} + +} // namespace tensorflow diff --git a/tensorflow/python/framework/python_api_dispatcher.h b/tensorflow/python/framework/python_api_dispatcher.h new file mode 100644 index 00000000000..7cb3879dd74 --- /dev/null +++ b/tensorflow/python/framework/python_api_dispatcher.h @@ -0,0 +1,131 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_ +#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_ + +#include + +#include +#include + +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" + +namespace tensorflow { + +// Dispatch handler for Python APIs. +// +// A separate PythonAPIDispatcher object is created for each Python API, and +// keeps track of which parameters should be checked for dispatch. +// +// When PythonAPIDispatcher::Dispatch() is called with a tuple of +// canonicalized parameters, it checks the indicated parameters' values for +// `__tf_dispatch__` methods. If found, then this method is called with the +// following arguments: `__tf_dispatch__(api_name, api_func, canon_args)`, +// where: +// +// * `api_name` is the fully-qualified name of the python API (e.g., +// `"tf.math.sum"`). +// * `api_func` is the function that implements the APIs for `Tensor` inputs. +// * `canon_args` is the canonicalized argument list. +// +class PythonAPIDispatcher { + public: + // Information about an API parameter that supports dispatch. `index` is the + // parameter's index in the canonicalized parameter list, and `is_list` is + // true if the parameter expects a list of values (e.g. the `values` parameter + // to `tf.concat`). + struct ParamInfo { + int index; + bool is_list; + }; + + // Constructs a PythonAPIDispatcher. + // + // Args: + // api_name: The fully qualified name of the API handled by this dispatcher. + // api_func: The python function for which implements the API for `Tensor` + // inputs. + // num_params: The number of canonical parameters that the API expects. + // right_to_left: If true, then the normal precedence rules (in which + // dispatchers are tried from left-to-right) are changed to try + // dispatchers from right-to-left instead. This is used for operations + // such as `__radd__`, where the normal parameter order is reversed. + PythonAPIDispatcher(const std::string& api_name, PyObject* api_func, + int num_params, bool right_to_left = false); + + // Initiliaze this PythonAPIDispatcher with information about which parameters + // support dispatch. Returns true on success, or sets a python exception and + // returns false on error. + bool Initialize(std::vector dispatchable_params); + + // Checks if any of the dispatchable parameters have a `__tf_dispatch__` + // method, and if so, calls them. In particular, this method: + // + // 1. Constructs an ordered list of dispatchable types. + // + // * Checks each argument that support dispatch to see if its value(s) have + // a `__tf_dispatch__` method. + // * Arguments are checked left-to-right unless `right_to_left` was set to + // True in the constructor. *Within* a list-valued parameter, elements + // are always checked left-to-right (even if `right_to_left` is True). + // * Duplicate types are removed (only the first occurrence of each type is + // kept). + // * If any type `T_sub` is a subtype of another type `T_super`, but occurs + // after `T_super` in the list of dispatchable types, then it is moved to + // just before `T_super`. + // + // 2. Tries calling each of the dispatchable types' `__tf_dispatch__` methods. + // + // * Dispatch methods are called with the following arguments: + // `__tf_dispatch__(api_name, api_func, canon_args)` + // * Dispatch methods are tried in the order described above. + // * If a dispatch method returns a value, then `Dispatch()` returns a + // new reference to that value. + // * If a dispatch method raises an exception, then `Dispatch()` returns + // null (i.e., propogates the exception). + // * If a dispatch method returns `NotImplemented`, then the dispatcher + // moves on to the next type. + // + // 3. If no dispatchers for found, or all dispatchers returned + // `NotImplemented', then the dispatcher returns a *borrowed* reference + // to `Py_NotImplemented`. + // + // Args: + // params: A `PyTuple` containing the canonicalized parameters to the API. + // All `POSITIONAL_OR_KEYWORD` arguments must be converted to positional + // arguments (`KEYWORD_ONLY` arguments are not currently supported). Any + // dispatchable parameter with `is_list=True` must have been converted to + // `PyList`. + // + // Returns: + // * If a `__tf_dispatch__` handler successfully handled the API: + // Returns a *new* reference to the handler's return value. + // * If no handler was found, or all handlers returned NotImplemented: + // Returns a *borrowed* reference to `Py_NotImplemented`. + // * On error: Sets an exception and returns `nullptr`. + PyObject* Dispatch(PyObject* params) const; + + private: + Safe_PyObjectPtr api_name_; + Safe_PyObjectPtr api_func_; + int num_params_; + std::vector dispatchable_params_; + bool right_to_left_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_ diff --git a/tensorflow/python/framework/python_api_dispatcher_test.py b/tensorflow/python/framework/python_api_dispatcher_test.py new file mode 100644 index 00000000000..51dda8a0f9f --- /dev/null +++ b/tensorflow/python/framework/python_api_dispatcher_test.py @@ -0,0 +1,244 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.python_api_dispatcher.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python import _pywrap_python_api_dispatcher +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class Trace(object): + """A dispatchable type that builds traces of ops it's called with.""" + + log = [] + + def __init__(self, api_name, *args): + self.api_name = api_name + self.args = args + + @classmethod + def __tf_dispatch__(cls, api_name, api_func, args): + Trace.log.append("__tf_dispatch__%s" % ((cls.__name__, api_name),)) + if "disabled" in str(args) or api_name == "disabled": + return NotImplemented + del api_func # not used + return cls(api_name, *args) + + def __repr__(self): + return "%s%s" % (type(self).__name__, (self.api_name,) + self.args) + + def __eq__(self, other): + return (type(self) is type(other) and self.api_name == other.api_name and + self.args == other.args) + + +class Trace2(Trace): + pass + + +class Trace2B(Trace2): + pass + + +class Trace3(Trace): + pass + + +class Trace4(Trace): + pass + + +class WeightedTensor(object): + + def __init__(self, tensor, weight): + self.tensor = ops.convert_to_tensor(tensor) + self.weight = weight # Python float + + @classmethod + def __tf_dispatch__(cls, api_name, api_func, args): + del api_name # unused + weights = [arg.weight for arg in args if isinstance(arg, WeightedTensor)] + tensors = [ + arg.tensor if isinstance(arg, WeightedTensor) else arg for arg in args + ] + tensor_result = api_func(*tensors) + avg_weight = sum(weights) / len(weights) + return cls(tensor_result, avg_weight) + + +@test_util.run_all_in_graph_and_eager_modes +class PythonAPIDispatcherTest(test_util.TensorFlowTestCase, + parameterized.TestCase): + + def testNoDispatchableTypes(self): + add_dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher( + "tf.math.add", math_ops.add, 2, [0, 1], [], False) + self.assertEqual(add_dispatcher.Dispatch(1, 2), NotImplemented) + + concat_dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher( + "tf.concat", array_ops.concat, 2, [1], [0], False) + self.assertEqual(concat_dispatcher.Dispatch([1], 0), NotImplemented) + + def testSimpleDispatchWithTrace(self): + dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher( + "tf.math.add", math_ops.add, 2, [0, 1], [], False) + x = 5 + y = Trace("constant", "y") + z = Trace("constant", "z") + + Trace.log.clear() + self.assertEqual(dispatcher.Dispatch(x, y), Trace("tf.math.add", x, y)) + self.assertEqual(dispatcher.Dispatch(y, x), Trace("tf.math.add", y, x)) + self.assertEqual(dispatcher.Dispatch(y, z), Trace("tf.math.add", y, z)) + self.assertEqual(Trace.log, [ + "__tf_dispatch__('Trace', 'tf.math.add')", + "__tf_dispatch__('Trace', 'tf.math.add')", + "__tf_dispatch__('Trace', 'tf.math.add')" + ]) + + def testDispatcherReturnsNotImplemented(self): + dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher( + "tf.math.add", math_ops.add, 2, [0, 1], [], False) + x = 5 + y = Trace("constant", "disabled") + z = Trace("constant", "z") + + self.assertEqual(dispatcher.Dispatch(x, y), NotImplemented) + self.assertEqual(dispatcher.Dispatch(y, x), NotImplemented) + self.assertEqual(dispatcher.Dispatch(y, z), NotImplemented) + self.assertEqual(dispatcher.Dispatch(z, z), Trace("tf.math.add", z, z)) + + def testSimpleDispatchWithWeightedTensor(self): + dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher( + "tf.math.add", math_ops.add, 2, [0, 1], [], False) + x = 5 + y = WeightedTensor([1, 2, 3], 0.6) + z = WeightedTensor([10, 20, 30], 0.2) + + x_plus_y = dispatcher.Dispatch(x, y) + y_plus_x = dispatcher.Dispatch(y, x) + y_plus_z = dispatcher.Dispatch(y, z) + + self.assertAllEqual(x_plus_y.tensor, [6, 7, 8]) + self.assertAllEqual(y_plus_x.tensor, [6, 7, 8]) + self.assertAllEqual(y_plus_z.tensor, [11, 22, 33]) + + self.assertEqual(x_plus_y.weight, 0.6) + self.assertEqual(y_plus_x.weight, 0.6) + self.assertEqual(y_plus_z.weight, 0.4) + + def testDispatchPrecedence(self): + # We use an API for which dispatch is disabled, so all dispatchers get + # called (since this test checks the order of the dispatcher list). + dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher( + "disabled", None, 5, [0, 1, 4], [2, 3], False) + + t = Trace("constant", "t") + t2_1 = Trace2("constant", "t2_1") + t2_2 = Trace2("constant", "t2_2") + t2b = Trace2B("constant", "t2b") + t3 = Trace3("constant", "t3") + t4 = Trace4("constant", "t4") + + # Three dispatchable types, none of which is a subclass of the other: + # * precedence is left-to-right. + # * duplicates are removed. + Trace.log.clear() + result = dispatcher.Dispatch(t2_1, t3, [], [t2_2, t3], t4) + self.assertEqual(result, NotImplemented) + self.assertEqual(Trace.log, [ + "__tf_dispatch__('Trace2', 'disabled')", + "__tf_dispatch__('Trace3', 'disabled')", + "__tf_dispatch__('Trace4', 'disabled')" + ]) + + # Subtypes are moved before their base types. + Trace.log.clear() + result = dispatcher.Dispatch(t2_1, t3, [t], [t2_2, t, t3, t4], t2b) + self.assertEqual(result, NotImplemented) + self.assertEqual(Trace.log, [ + "__tf_dispatch__('Trace2B', 'disabled')", + "__tf_dispatch__('Trace2', 'disabled')", + "__tf_dispatch__('Trace3', 'disabled')", + "__tf_dispatch__('Trace4', 'disabled')", + "__tf_dispatch__('Trace', 'disabled')" + ]) + + def testDispatchPrecedenceRightToLeft(self): + # We use an API for which dispatch is disabled, so all dispatchers get + # called (since this test checks the order of the dispatcher list). + dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher( + "disabled", None, 5, [4, 0, 1], [2, 3], True) + + t = Trace("constant", "t") + t2_1 = Trace2("constant", "t2_1") + t2_2 = Trace2("constant", "t2_2") + t2b = Trace2B("constant", "t2b") + t3 = Trace3("constant", "t3") + t4 = Trace4("constant", "t4") + + # Three dispatchable types, none of which is a subclass of the other: + # * precedence is right_to_left (since we set right_to_left=True in the + # PtyonAPIDispatcher constructor). (Note: arguments are scanned + # right-to-left, but the elements of list arguments are still scanned + # left-to-right.) + # * duplicates are removed. + Trace.log.clear() + result = dispatcher.Dispatch(t2_1, t3, [], [t2_2, t3], t4) + self.assertEqual(result, NotImplemented) + self.assertEqual(Trace.log, [ + "__tf_dispatch__('Trace4', 'disabled')", + "__tf_dispatch__('Trace2', 'disabled')", + "__tf_dispatch__('Trace3', 'disabled')" + ]) + + # Subtypes are moved before their base types. (Note: moving subtypes occurs + # *after* we swap the order to be right-to-left; so the dispatch order here + # is not what we'd get by just reversing the final dispatch order if + # right_to_left were false.) + Trace.log.clear() + result = dispatcher.Dispatch(t2_1, t3, [t], [t2_2, t, t3, t4], t2b) + self.assertEqual(result, NotImplemented) + self.assertEqual(Trace.log, [ + "__tf_dispatch__('Trace2B', 'disabled')", + "__tf_dispatch__('Trace2', 'disabled')", + "__tf_dispatch__('Trace3', 'disabled')", + "__tf_dispatch__('Trace4', 'disabled')", + "__tf_dispatch__('Trace', 'disabled')" + ]) + + def testDispatchParamOutOfRange(self): + with self.assertRaisesRegex(ValueError, "index out of range"): + _pywrap_python_api_dispatcher.PythonAPIDispatcher("some_api", None, 5, + [0, 1, 5], [2, 3], True) + with self.assertRaisesRegex(ValueError, "index out of range"): + _pywrap_python_api_dispatcher.PythonAPIDispatcher("some_api", None, 5, + [0, -3], [2, 3], True) + with self.assertRaisesRegex(ValueError, "index out of range"): + _pywrap_python_api_dispatcher.PythonAPIDispatcher("some_api", None, 5, + [0, 1], [10, 3], True) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/python_api_dispatcher_wrapper.cc b/tensorflow/python/framework/python_api_dispatcher_wrapper.cc new file mode 100644 index 00000000000..4f707a902e2 --- /dev/null +++ b/tensorflow/python/framework/python_api_dispatcher_wrapper.cc @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Note: This library is only used by python_api_dispatcher_test. It is +// not meant to be used in other circumstances. + +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" +#include "tensorflow/python/framework/python_api_dispatcher.h" + +namespace py = pybind11; + +namespace { + +tensorflow::PythonAPIDispatcher MakePythonAPIDispatcher( + const std::string& api_name, py::handle api_func, int num_params, + const std::vector& dispatch_params, + const std::vector& dispatch_list_params, bool right_to_left) { + std::vector dispatchable_params; + dispatchable_params.reserve(dispatch_params.size() + + dispatch_list_params.size()); + for (int p : dispatch_params) { + dispatchable_params.push_back({p, false}); + } + for (int p : dispatch_list_params) { + dispatchable_params.push_back({p, true}); + } + + auto dispatcher = tensorflow::PythonAPIDispatcher(api_name, api_func.ptr(), + num_params, right_to_left); + if (!dispatcher.Initialize(dispatchable_params)) { + throw py::error_already_set(); + } + return dispatcher; +} + +py::handle Dispatch(tensorflow::PythonAPIDispatcher* self, py::args args) { + auto result = self->Dispatch(args.ptr()); + if (result == nullptr) { + throw py::error_already_set(); + } else if (result == Py_NotImplemented) { + Py_INCREF(result); + return result; + } else { + return result; + } +} + +} // namespace + +PYBIND11_MODULE(_pywrap_python_api_dispatcher, m) { + py::class_(m, "PythonAPIDispatcher") + .def(py::init(&MakePythonAPIDispatcher)) + .def("Dispatch", Dispatch); +} diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 36bf78987a1..f55cf51062d 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -108,7 +108,7 @@ except Exception: # pylint: disable=broad-except # Uses the same mechanism as above to selectively enable/disable MLIR # compilation. def is_mlir_bridge_enabled(): - return False + return None try: @@ -1977,6 +1977,9 @@ def matmul_without_tf32(a, b, *args, **kwargs): If a matmul itself is being tested, or some other op which uses matmul, use `run_without_tensor_float_32` instead. + This also casts complex64 inputs to complex128, since TensorFloat-32 can also + be used with complex64 + Args: a: First input to tf.linalg.matmul b: Second input to tf.linalg.matmul @@ -1991,6 +1994,11 @@ def matmul_without_tf32(a, b, *args, **kwargs): b = math_ops.cast(b, "float64") ret = math_ops.matmul(a, b, *args, **kwargs) return math_ops.cast(ret, a.dtype) + elif config.tensor_float_32_execution_enabled() and a.dtype == "complex64": + a = math_ops.cast(a, "complex128") + b = math_ops.cast(b, "complex128") + ret = math_ops.matmul(a, b, *args, **kwargs) + return math_ops.cast(ret, a.dtype) else: return math_ops.matmul(a, b, *args, **kwargs) @@ -2022,8 +2030,13 @@ class TensorFlowTestCase(googletest.TestCase): # disable it here. pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True) + # Check if the mlir bridge has been explicitly enabled or disabled. If + # is_mlir_bridge_enabled() returns None, the user did not explictly enable + # or disable the bridge so do not update enable_mlir_bridge. if is_mlir_bridge_enabled(): context.context().enable_mlir_bridge = True + elif is_mlir_bridge_enabled() is not None: + context.context().enable_mlir_bridge = False self._threads = [] self._tempdir = None diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 7c8492a5f48..17aa8bf6c11 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -38,7 +38,7 @@ py_library( "//tensorflow/python/keras/datasets", "//tensorflow/python/keras/feature_column", "//tensorflow/python/keras/layers", - "//tensorflow/python/keras/mixed_precision/experimental:mixed_precision_experimental", + "//tensorflow/python/keras/mixed_precision:mixed_precision_experimental", "//tensorflow/python/keras/optimizer_v2", "//tensorflow/python/keras/premade", "//tensorflow/python/keras/preprocessing", @@ -339,7 +339,6 @@ tf_py_test( size = "small", srcs = ["activations_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":activations", ":backend", @@ -359,7 +358,6 @@ tf_py_test( size = "small", srcs = ["combinations_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":combinations", ":testing_utils", @@ -376,7 +374,6 @@ tf_py_test( size = "small", srcs = ["constraints_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":backend", ":combinations", @@ -391,7 +388,6 @@ tf_py_test( size = "small", srcs = ["initializers_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":backend", ":combinations", @@ -426,7 +422,6 @@ tf_py_test( python_version = "PY3", shard_count = 8, tags = ["notsan"], - tfrt_enabled = True, deps = [ ":keras", "//tensorflow/python:client_testlib", @@ -630,7 +625,6 @@ tf_py_test( size = "medium", srcs = ["backend_config_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":backend", ":backend_config", @@ -645,7 +639,6 @@ tf_py_test( srcs = ["keras_parameterized_test.py"], python_version = "PY3", tags = ["notsan"], - tfrt_enabled = True, deps = [ ":keras", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/api/BUILD b/tensorflow/python/keras/api/BUILD index d69930b7455..fba0cf557bc 100644 --- a/tensorflow/python/keras/api/BUILD +++ b/tensorflow/python/keras/api/BUILD @@ -73,9 +73,9 @@ keras_packages = [ "tensorflow.python.keras.layers.wrappers", "tensorflow.python.keras.losses", "tensorflow.python.keras.metrics", - "tensorflow.python.keras.mixed_precision.experimental.get_layer_policy", - "tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer", - "tensorflow.python.keras.mixed_precision.experimental.policy", + "tensorflow.python.keras.mixed_precision.get_layer_policy", + "tensorflow.python.keras.mixed_precision.loss_scale_optimizer", + "tensorflow.python.keras.mixed_precision.policy", "tensorflow.python.keras.models", "tensorflow.python.keras.optimizer_v2.adadelta", "tensorflow.python.keras.optimizer_v2.adagrad", diff --git a/tensorflow/python/keras/applications/BUILD b/tensorflow/python/keras/applications/BUILD index 8140a1ed806..c8151c031d3 100644 --- a/tensorflow/python/keras/applications/BUILD +++ b/tensorflow/python/keras/applications/BUILD @@ -63,7 +63,6 @@ tf_py_test( "no_rocm", "notsan", # b/168814536 ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", @@ -101,6 +100,7 @@ tf_py_test( tags = [ "no_oss", "no_pip", + "notsan", # TODO(b/170901700) ], deps = [ ":applications", diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py index 5887cfca594..67f837dfc58 100644 --- a/tensorflow/python/keras/applications/nasnet.py +++ b/tensorflow/python/keras/applications/nasnet.py @@ -62,20 +62,19 @@ NASNET_LARGE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-large-no-top.h5' layers = VersionAwareLayers() -def NASNet( - input_shape=None, - penultimate_filters=4032, - num_blocks=6, - stem_block_filters=96, - skip_reduction=True, - filter_multiplier=2, - include_top=True, - weights=None, - input_tensor=None, - pooling=None, - classes=1000, - default_size=None, - classifier_activation='softmax'): +def NASNet(input_shape=None, + penultimate_filters=4032, + num_blocks=6, + stem_block_filters=96, + skip_reduction=True, + filter_multiplier=2, + include_top=True, + weights='imagenet', + input_tensor=None, + pooling=None, + classes=1000, + default_size=None, + classifier_activation='softmax'): """Instantiates a NASNet model. Reference: diff --git a/tensorflow/python/keras/benchmarks/BUILD b/tensorflow/python/keras/benchmarks/BUILD index d25afb24f9a..1e249d3febf 100644 --- a/tensorflow/python/keras/benchmarks/BUILD +++ b/tensorflow/python/keras/benchmarks/BUILD @@ -73,7 +73,6 @@ cuda_py_test( tags = COMMON_TAGS + [ "no_oss_py38", # TODO(b/162044699) ], - tfrt_enabled = True, deps = [ ":profiler_lib", "//tensorflow:tensorflow_py", @@ -85,7 +84,6 @@ cuda_py_test( name = "model_components_benchmarks_test", srcs = ["model_components_benchmarks_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":profiler_lib", "//tensorflow:tensorflow_py", @@ -107,7 +105,6 @@ cuda_py_test( srcs = ["keras_examples_benchmarks/bidirectional_lstm_benchmark_test.py"], python_version = "PY3", tags = COMMON_TAGS, - tfrt_enabled = True, deps = [ ":benchmark_util", ":profiler_lib", @@ -120,7 +117,6 @@ cuda_py_test( srcs = ["keras_examples_benchmarks/text_classification_transformer_benchmark_test.py"], python_version = "PY3", tags = COMMON_TAGS, - tfrt_enabled = True, deps = [ ":benchmark_util", "//tensorflow:tensorflow_py", @@ -132,7 +128,6 @@ cuda_py_test( srcs = ["keras_examples_benchmarks/antirectifier_benchmark_test.py"], python_version = "PY3", tags = COMMON_TAGS, - tfrt_enabled = True, deps = [ ":benchmark_util", "//tensorflow:tensorflow_py", @@ -144,7 +139,6 @@ cuda_py_test( srcs = ["keras_examples_benchmarks/mnist_conv_benchmark_test.py"], python_version = "PY3", tags = COMMON_TAGS, - tfrt_enabled = True, deps = [ ":benchmark_util", "//tensorflow:tensorflow_py", @@ -157,7 +151,6 @@ cuda_py_test( srcs = ["keras_examples_benchmarks/mnist_hierarchical_rnn_benchmark_test.py"], python_version = "PY3", tags = COMMON_TAGS, - tfrt_enabled = True, deps = [ ":benchmark_util", "//tensorflow:tensorflow_py", @@ -169,7 +162,6 @@ cuda_py_test( srcs = ["keras_examples_benchmarks/mnist_irnn_benchmark_test.py"], python_version = "PY3", tags = COMMON_TAGS, - tfrt_enabled = True, deps = [ ":benchmark_util", "//tensorflow:tensorflow_py", @@ -181,7 +173,6 @@ cuda_py_test( srcs = ["keras_examples_benchmarks/reuters_mlp_benchmark_test.py"], python_version = "PY3", tags = COMMON_TAGS, - tfrt_enabled = True, deps = [ ":benchmark_util", "//tensorflow:tensorflow_py", @@ -194,7 +185,6 @@ cuda_py_test( srcs = ["keras_examples_benchmarks/cifar10_cnn_benchmark_test.py"], python_version = "PY3", tags = COMMON_TAGS, - tfrt_enabled = True, deps = [ ":benchmark_util", "//tensorflow:tensorflow_py", @@ -206,7 +196,6 @@ cuda_py_test( srcs = ["keras_examples_benchmarks/mnist_conv_custom_training_benchmark_test.py"], python_version = "PY3", tags = COMMON_TAGS, - tfrt_enabled = True, deps = [ ":distribution_util", "//tensorflow:tensorflow_py", diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD b/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD new file mode 100644 index 00000000000..7c3b55c02bd --- /dev/null +++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/BUILD @@ -0,0 +1,66 @@ +# Description: +# Implementation of benchmarks on Keras layers. + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "all_py_srcs", + srcs = glob(["*.py"]), + visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"], +) + +BECHMARK_TAGS = [ + "no_oss_py38", # TODO(b/162044699) + "no_pip", # TODO(b/161253163) + "no_windows", # TODO(b/160628318) +] + +# To run CPU benchmarks: +# bazel run -c opt benchmarks_test -- --benchmarks=. + +# To run GPU benchmarks: +# bazel run -c opt --config=cuda benchmarks_test -- \ +# --benchmarks=. + +# To run benchmarks with TFRT: +# bazel run -c opt --config=cuda --test_env=EXPERIMENTAL_ENABLE_TFRT=1 benchmarks_test -- \ +# --benchmarks=. + +# To run a subset of benchmarks using --benchmarks flag. +# --benchmarks: the list of benchmarks to run. The specified value is interpreted +# as a regular expression and any benchmark whose name contains a partial match +# to the regular expression is executed. +# e.g. --benchmarks=".*lstm*." will run all lstm layer related benchmarks. + +py_library( + name = "run_xprof", + srcs = ["run_xprof.py"], + visibility = ["//tensorflow:internal"], +) + +py_library( + name = "layer_benchmarks_test_base", + srcs = ["layer_benchmarks_test_base.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":run_xprof", + "//tensorflow:tensorflow_py", + "//tensorflow/python/keras/benchmarks:profiler_lib", + ], +) + +tf_py_test( + name = "layer_benchmarks_test", + srcs = ["layer_benchmarks_test.py"], + python_version = "PY3", + tags = BECHMARK_TAGS, + deps = [ + ":layer_benchmarks_test_base", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py new file mode 100644 index 00000000000..57f2b18e982 --- /dev/null +++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py @@ -0,0 +1,128 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmarks on Keras layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import six + +import tensorflow as tf +from tensorflow.python.keras.benchmarks.layer_benchmarks import layer_benchmarks_test_base +from tensorflow.python.platform import benchmark + + +def _layer_call_backward(layer, x): + with tf.GradientTape() as tape: + y = layer(x) + loss = tf.reduce_mean(y**2) + + _ = tape.gradient(loss, layer.trainable_variables) + + +class KerasLayerBenchmarks(six.with_metaclass( + benchmark.ParameterizedBenchmark, + layer_benchmarks_test_base.LayerBenchmarksBase)): + + _benchmark_parameters = [ + ("Conv2D_small_shape", tf.keras.layers.Conv2D, + {"filters": 1, "kernel_size": 1, "activation": "relu"}, + (1, 1, 1, 1), 10000), + ("Conv2D_normal_shape", tf.keras.layers.Conv2D, + {"filters": 1, "kernel_size": 1, "activation": "relu"}, + (64, 28, 28, 3), 10000), + ("LSTM_small_shape", tf.keras.layers.LSTM, + {"units": 1}, (1, 1, 1), 10000), + ("LSTM_normal_shape", tf.keras.layers.LSTM, + {"units": 4}, (32, 10, 8), 10000), + ] + + def benchmark_layer_call(self, layer_cls, layer_args, input_shape, num_iters): + layer = layer_cls(**layer_args) + x = tf.ones(input_shape) + + fn = functools.partial(layer, x) + self.run_report(fn, num_iters) + + def benchmark_layer_call_with_function( + self, layer_cls, layer_args, input_shape, num_iters): + layer = layer_cls(**layer_args) + x = tf.ones(input_shape) + layer.call = tf.function(layer.call) + + fn = functools.partial(layer, x) + self.run_report(fn, num_iters) + + def benchmark_layer_call_with_xla( + self, layer_cls, layer_args, input_shape, num_iters): + layer = layer_cls(**layer_args) + x = tf.ones(input_shape) + layer.call = tf.function( + layer.call, experimental_compile=True) + + fn = functools.partial(layer, x) + self.run_report(fn, num_iters) + + def benchmark_layer_call_backward( + self, layer_cls, layer_args, input_shape, num_iters): + layer = layer_cls(**layer_args) + x = tf.ones(input_shape) + + fn = functools.partial(_layer_call_backward, layer, x) + self.run_report(fn, num_iters) + + def benchmark_layer_call_backward_with_function( + self, layer_cls, layer_args, input_shape, num_iters): + layer = layer_cls(**layer_args) + x = tf.ones(input_shape) + layer.call = tf.function(layer.call) + + fn = functools.partial(_layer_call_backward, layer, x) + self.run_report(fn, num_iters) + + +class KerasLayerBenchmarksBackwardXLA(six.with_metaclass( + benchmark.ParameterizedBenchmark, + layer_benchmarks_test_base.LayerBenchmarksBase)): + + _benchmark_parameters = [ + ("Conv2D_small_shape", tf.keras.layers.Conv2D, + {"filters": 1, "kernel_size": 1, "activation": "relu"}, + (1, 1, 1, 1), 10000), + ("Conv2D_normal_shape", tf.keras.layers.Conv2D, + {"filters": 1, "kernel_size": 1, "activation": "relu"}, + (64, 28, 28, 3), 10000), + # TODO(b/153480400) + # ("LSTM_small_shape", tf.keras.layers.LSTM, + # {"units": 1}, (1, 1, 1), 10000), + # ("LSTM_normal_shape", tf.keras.layers.LSTM, + # {"units": 4}, (32, 10, 8), 10000), + ] + + def benchmark_layer_call_backward_with_xla( + self, layer_cls, layer_args, input_shape, num_iters): + layer = layer_cls(**layer_args) + x = tf.ones(input_shape) + layer.call = tf.function( + layer.call, experimental_compile=True) + + fn = functools.partial(_layer_call_backward, layer, x) + self.run_report(fn, num_iters) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test_base.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test_base.py new file mode 100644 index 00000000000..94595c95449 --- /dev/null +++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test_base.py @@ -0,0 +1,72 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Benchmark base to run and report Keras layers benchmark results.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +import tensorflow as tf + +from tensorflow.python.keras.benchmarks.layer_benchmarks import run_xprof + + +class LayerBenchmarksBase(tf.test.Benchmark): + """Run and report benchmark results. + + The first run is without any profiling to purly measure running time. + Second run is with xprof but no python trace. + Third run is with xprof and python trace. + Note: xprof runs fewer iterations, and the maximum iterations is 100. + """ + + def run_report(self, func, num_iters): + """Run and report benchmark results for different settings.""" + + # 0. Warm up. + func() + + # 1. Run without profiling. + start = time.time() + for _ in range(num_iters): + func() + total_time = time.time() - start + us_mean_time = total_time * 1e6 / num_iters + + metrics = [ + {"name": "examples_per_sec", + "value": float("{0:.3f}".format(num_iters / total_time))}, + {"name": "us_per_example", + "value": float("{0:.3f}".format(us_mean_time))}] + + # 2. Run with xprof with no python trace. + num_iters_xprof = min(100, num_iters) + xprof_link, us_per_example = run_xprof.run_with_xprof( + func, num_iters_xprof, False) + # This xprof link will appear in the benchmark dashboard. + extras = { + "xprof_link": xprof_link, + "us_per_example_with_xprof": us_per_example + } + + # 3. Run with xprof and python trace. + xprof_link, us_per_example = run_xprof.run_with_xprof( + func, num_iters_xprof, True) + extras["xprof_with_python_trace"] = xprof_link + extras["us_per_example_with_xprof_and_python"] = us_per_example + + self.report_benchmark( + iters=num_iters, wall_time=us_mean_time, extras=extras, metrics=metrics) diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/run_xprof.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/run_xprof.py new file mode 100644 index 00000000000..aef4d7b9877 --- /dev/null +++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/run_xprof.py @@ -0,0 +1,40 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import as _absolute_import +from __future__ import division as _division +from __future__ import print_function as _print_function + +import time +import uuid + +from tensorflow.python.profiler import profiler_v2 as profiler + +def run_with_xprof(self, func, num_iters_xprof=100, enable_python_trace=True, + logdir='/tmp/layer_benchmark_xprof/'): + suid = str(uuid.uuid4()) + if enable_python_trace: + options = profiler.ProfilerOptions(python_tracer_level=1) + logdir = os.path.join(logdir, str(uuid.uuid4()) + "_with_python") + else: + options = profiler.ProfilerOptions(python_tracer_level=0) + logdir = os.path.join(logdir, suid) + + start = time.time() + with profiler.Profile(logdir, options): + for _ in range(num_iters_xprof): + func() + total_time = time.time() - start + us_per_example = float("{0:.3f}".format(total_time * 1e6 / num_iters_xprof)) + return logdir, us_per_example diff --git a/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD b/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD index 5501aedcd4e..c99d062143f 100644 --- a/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD +++ b/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD @@ -43,7 +43,6 @@ cuda_py_test( "no_pip", # b/161253163 "no_windows", # b/160628318 ], - tfrt_enabled = True, deps = [ ":saved_model_benchmark_util", "//tensorflow:tensorflow_py", @@ -58,7 +57,6 @@ cuda_py_test( "no_pip", # b/161253163 "no_windows", # b/160628318 ], - tfrt_enabled = True, deps = [ ":saved_model_benchmark_util", "//tensorflow:tensorflow_py", @@ -73,7 +71,6 @@ cuda_py_test( "no_pip", # b/161253163 "no_windows", # b/160628318 ], - tfrt_enabled = True, deps = [ ":saved_model_benchmark_util", "//tensorflow:tensorflow_py", @@ -88,7 +85,6 @@ cuda_py_test( "no_pip", # b/161253163 "no_windows", # b/160628318 ], - tfrt_enabled = True, deps = [ ":saved_model_benchmark_util", "//tensorflow:tensorflow_py", @@ -103,7 +99,6 @@ cuda_py_test( "no_pip", # b/161253163 "no_windows", # b/160628318 ], - tfrt_enabled = True, deps = [ ":saved_model_benchmark_util", "//tensorflow:tensorflow_py", @@ -118,7 +113,6 @@ cuda_py_test( "no_pip", # b/161253163 "no_windows", # b/160628318 ], - tfrt_enabled = True, deps = [ ":saved_model_benchmark_util", "//tensorflow:tensorflow_py", @@ -133,7 +127,6 @@ cuda_py_test( "no_pip", # b/161253163 "no_windows", # b/160628318 ], - tfrt_enabled = True, deps = [ ":saved_model_benchmark_util", "//tensorflow:tensorflow_py", @@ -148,7 +141,6 @@ cuda_py_test( "no_pip", # b/161253163 "no_windows", # b/160628318 ], - tfrt_enabled = True, deps = [ ":saved_model_benchmark_util", "//tensorflow:tensorflow_py", diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 79c36c0f559..3cb36019dab 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -49,8 +49,8 @@ py_library( "//tensorflow/python/keras:losses", "//tensorflow/python/keras:optimizers", "//tensorflow/python/keras:regularizers", - "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable", - "//tensorflow/python/keras/mixed_precision/experimental:policy", + "//tensorflow/python/keras/mixed_precision:autocast_variable", + "//tensorflow/python/keras/mixed_precision:policy", "//tensorflow/python/keras/saving", "//tensorflow/python/keras/utils:engine_utils", "//tensorflow/python/keras/utils:mode_keys", @@ -136,7 +136,6 @@ cuda_py_test( srcs = ["worker_training_state_test.py"], python_version = "PY3", shard_count = 4, - tfrt_enabled = True, deps = [ ":multi_worker_testing_utils", ":worker_training_state", @@ -150,7 +149,6 @@ cuda_py_test( distribute_py_test( name = "checkpointing_test", srcs = ["checkpointing_test.py"], - disable_mlir_bridge = False, main = "checkpointing_test.py", tags = [ "multi_and_single_gpu", @@ -171,6 +169,7 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "nomsan", # TODO(b/162894966) + "notsan", # TODO(b/171040408): data race ], # b/155301154 broken with XLA:GPU xla_enable_strict_auto_jit = True, @@ -197,8 +196,8 @@ cuda_py_test( "//tensorflow/python/keras:testing_utils", "//tensorflow/python/keras/engine", "//tensorflow/python/keras/layers", - "//tensorflow/python/keras/mixed_precision/experimental:policy", - "//tensorflow/python/keras/mixed_precision/experimental:test_util", + "//tensorflow/python/keras/mixed_precision:policy", + "//tensorflow/python/keras/mixed_precision:test_util", "//tensorflow/python/ops/losses", "@absl_py//absl/testing:parameterized", ], @@ -251,9 +250,11 @@ distribute_py_test( main = "custom_training_loop_models_test.py", tags = [ "multi_and_single_gpu", + "notsan", # TODO(b/170954243) ], tpu_tags = [ "no_oss", # b/153615544. + "notsan", # TODO(b/170869466) ], deps = [ "//tensorflow/python:math_ops", @@ -339,11 +340,13 @@ distribute_py_test( full_precision = True, main = "distribute_strategy_test.py", python_version = "PY3", - shard_count = 10, + shard_count = 20, tags = [ "multi_and_single_gpu", "no_rocm", # times out on ROCm "no_windows_gpu", + "noasan", # TODO(b/170902997) + "notap", # TODO(b/170902997) "notsan", ], tpu_tags = [ @@ -413,6 +416,7 @@ distribute_py_test( "multi_and_single_gpu", "no_rocm", # times out on ROCm "no_windows_gpu", + "nogpu", # TODO(b/170905292) "notsan", ], deps = [ @@ -429,7 +433,9 @@ distribute_py_test( main = "keras_embedding_model_correctness_test.py", shard_count = 8, tags = [ + "broken", # b/170975619 "multi_and_single_gpu", + "no_rocm", "no_windows_gpu", "notsan", ], @@ -450,6 +456,7 @@ distribute_py_test( "multi_and_single_gpu", "no_rocm", # times out on ROCm "no_windows_gpu", + "noasan", # TODO(b/337374867) fails with -fsanitize=null "notsan", ], xla_enable_strict_auto_jit = False, # Tensorflow also fails. @@ -505,9 +512,8 @@ distribute_py_test( shard_count = 31, tags = [ "multi_and_single_gpu", - "no_cuda11", - "no_oss", "no_windows_gpu", + "noasan", # TODO(b/337374867) fails with -fsanitize=null "notpu", # TODO(b/153672562) "notsan", ], @@ -520,7 +526,6 @@ distribute_py_test( name = "keras_save_load_test", size = "medium", srcs = ["keras_save_load_test.py"], - disable_mlir_bridge = False, full_precision = True, main = "keras_save_load_test.py", shard_count = 7, @@ -726,6 +731,7 @@ py_test( srcs = ["multi_worker_callback_tf2_test.py"], python_version = "PY3", shard_count = 5, + tags = ["no_oss_py38"], #TODO(b/171435331) deps = [ "//tensorflow/python/distribute:collective_all_reduce_strategy", "//tensorflow/python/distribute:combinations", @@ -790,7 +796,6 @@ distribute_py_test( name = "saved_model_save_load_test", size = "medium", srcs = ["saved_model_save_load_test.py"], - disable_mlir_bridge = False, full_precision = True, main = "saved_model_save_load_test.py", shard_count = 7, @@ -808,7 +813,6 @@ distribute_py_test( name = "saved_model_mixed_api_test", size = "medium", srcs = ["saved_model_mixed_api_test.py"], - disable_mlir_bridge = False, full_precision = True, main = "saved_model_mixed_api_test.py", shard_count = 7, @@ -861,11 +865,12 @@ py_test( "//tensorflow/python:variables", "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/distribute:parameter_server_strategy_v2", "//tensorflow/python/distribute:sharded_variable", - "//tensorflow/python/distribute/client", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", + "//tensorflow/python/distribute/coordinator:cluster_coordinator", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", diff --git a/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py index f22a954d78d..71830f4d83a 100644 --- a/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py +++ b/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py @@ -42,8 +42,8 @@ from tensorflow.python.keras import layers from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.engine import training -from tensorflow.python.keras.mixed_precision.experimental import policy -from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util +from tensorflow.python.keras.mixed_precision import policy +from tensorflow.python.keras.mixed_precision import test_util as mp_test_util from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn @@ -309,51 +309,53 @@ class LocalCollectiveAllReduceStrategy( with policy.policy_scope('mixed_float16'): self._test_mixed_precision(None, None, required_gpus) +# TODO(b/170360740): Timeout in OSS +if not multi_process_runner.is_oss(): -@ds_combinations.generate( - combinations.combine( - strategy=[ - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - ], - mode=['eager'])) -class DistributedCollectiveAllReduceStrategyEagerTest(test.TestCase, - parameterized.TestCase): + @ds_combinations.generate( + combinations.combine( + strategy=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + ], + mode=['eager'])) + class DistributedCollectiveAllReduceStrategyEagerTest(test.TestCase, + parameterized.TestCase): - def testFitWithoutStepsPerEpochPartialBatch(self, strategy): + def testFitWithoutStepsPerEpochPartialBatch(self, strategy): - def _model_fn(): - x = layers.Input(shape=(1,), name='input') - y = layers.Dense(1, name='dense')(x) - model = training.Model(x, y) - return model + def _model_fn(): + x = layers.Input(shape=(1,), name='input') + y = layers.Dense(1, name='dense')(x) + model = training.Model(x, y) + return model - def _get_dataset(): - inputs = array_ops.expand_dims_v2(constant_op.constant(range(10)), axis=1) - targets = array_ops.expand_dims_v2( - constant_op.constant(range(10)), axis=1) - # Make global batch size 12 for 2 replicas and a non-repeated dataset with - # 10 elements so that we have partial batch - dataset = dataset_ops.Dataset.from_tensor_slices( - (inputs, targets)).batch(12, drop_remainder=False) - return dataset + def _get_dataset(): + inputs = array_ops.expand_dims_v2( + constant_op.constant(range(10)), axis=1) + targets = array_ops.expand_dims_v2( + constant_op.constant(range(10)), axis=1) + # Make global batch size 12 for 2 replicas and a non-repeated dataset + # with 10 elements so that we have partial batch + dataset = dataset_ops.Dataset.from_tensor_slices( + (inputs, targets)).batch( + 12, drop_remainder=False) + return dataset + + with strategy.scope(): + optimizer_fn = gradient_descent_keras.SGD + optimizer = optimizer_fn(0.001) + model = _model_fn() + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + dataset = _get_dataset() + kernel_before = model.get_weights()[0][0] + model.fit(dataset, epochs=10) + kernel_after = model.get_weights()[0][0] + self.assertNotEqual(kernel_before, kernel_after) + self.assertGreater(abs(kernel_before - 1), abs(kernel_after - 1)) - with strategy.scope(): - optimizer_fn = gradient_descent_keras.SGD - optimizer = optimizer_fn(0.001) - model = _model_fn() - loss = 'mse' - metrics = ['mae'] - model.compile( - optimizer, - loss, - metrics=metrics) - dataset = _get_dataset() - kernel_before = model.get_weights()[0][0] - model.fit(dataset, epochs=10) - kernel_after = model.get_weights()[0][0] - self.assertNotEqual(kernel_before, kernel_after) - self.assertGreater(abs(kernel_before-1), abs(kernel_after-1)) if __name__ == '__main__': v2_compat.enable_v2_behavior() diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index 7bc70101fb4..e18eba3ae5a 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -29,7 +29,6 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.distribute import central_storage_strategy -from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import combinations as ds_combinations from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import mirrored_strategy @@ -52,7 +51,7 @@ from tensorflow.python.keras.distribute import distributed_training_utils from tensorflow.python.keras.distribute import distributed_training_utils_v1 from tensorflow.python.keras.distribute import optimizer_combinations from tensorflow.python.keras.engine import base_layer_utils -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.keras.utils import np_utils from tensorflow.python.ops import array_ops @@ -1151,9 +1150,6 @@ class TestDistributionStrategyWithDatasets(test.TestCase, if mode == 'graph' and _is_tpu_strategy(distribution): self.skipTest('partial batch not supported with TPU in graph mode.') - if isinstance(distribution, - collective_all_reduce_strategy.CollectiveAllReduceStrategy): - self.skipTest('EOF error causes subsequent collective ops fail.') with self.cached_session(): with distribution.scope(): optimizer_fn = gradient_descent_keras.SGD @@ -1166,8 +1162,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase, loss, metrics=metrics) - inputs = np.zeros((1000, 3), dtype=np.float32) - targets = np.zeros((1000, 4), dtype=np.float32) + inputs = np.zeros((100, 3), dtype=np.float32) + targets = np.zeros((100, 4), dtype=np.float32) # steps/steps_per_epoch are calculated when using numpy arrays as # input data. fit_with_numpy = model.fit( @@ -2453,12 +2449,11 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, task_type='worker', task_id=1, num_accelerators={'GPU': 0}) - distribution = parameter_server_strategy.ParameterServerStrategy( + distribution = parameter_server_strategy.ParameterServerStrategyV1( cluster_resolver) self.assertIsInstance(distribution, - (parameter_server_strategy.ParameterServerStrategyV1, - parameter_server_strategy.ParameterServerStrategy)) + parameter_server_strategy.ParameterServerStrategyV1) with self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*'): diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py index 77a5f290439..f40f45cccbb 100644 --- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py +++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py @@ -32,7 +32,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_combinations as combinations from tensorflow.python.keras.distribute import distributed_training_utils -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.preprocessing import sequence from tensorflow.python.platform import test from tensorflow.python.util import nest diff --git a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py index 7e00ff5ec7f..43092fc2191 100644 --- a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py @@ -28,7 +28,7 @@ from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.layers import recurrent as rnn_v1 from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2 -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py index ea4b349b1cf..fcef87faa9c 100644 --- a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py @@ -183,6 +183,8 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): def proc_model_checkpoint_works_with_same_file_path(test_obj, saving_filepath): + if multi_process_runner.is_oss(): + test_obj.skipTest('TODO(b/170838633): Failing in OSS') model, _, train_ds, steps = _model_setup(test_obj, file_format='') num_epoch = 4 @@ -205,7 +207,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): raise multi_process_runner.get_barrier().wait() - backup_filepath = os.path.join(bar_dir, 'checkpoint') + backup_filepath = os.path.join(bar_dir, 'chief', 'checkpoint') test_obj.assertTrue(file_io.file_exists_v2(backup_filepath)) test_obj.assertTrue(file_io.file_exists_v2(saving_filepath)) diff --git a/tensorflow/python/keras/distribute/parameter_server_training_test.py b/tensorflow/python/keras/distribute/parameter_server_training_test.py index e4801d909ec..503dd68eb71 100644 --- a/tensorflow/python/keras/distribute/parameter_server_training_test.py +++ b/tensorflow/python/keras/distribute/parameter_server_training_test.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for ParameterServerClient and Keras models.""" +"""Tests for ClusterCoordinator and Keras models.""" from __future__ import absolute_import from __future__ import division @@ -21,14 +21,16 @@ from __future__ import print_function import random import tempfile +from absl.testing import parameterized from tensorflow.python import keras from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import combinations from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import parameter_server_strategy_v2 -from tensorflow.python.distribute.client import client as client_lib from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op @@ -44,7 +46,15 @@ from tensorflow.python.platform import test from tensorflow.python.training.server_lib import ClusterSpec -def make_client(num_workers, num_ps): +# These vocabularies usually come from TFT or a Beam pipeline. +FEATURE_VOCAB = [ + "avenger", "ironman", "batman", "hulk", "spiderman", "kingkong", + "wonder_woman" +] +LABEL_VOCAB = ["yes", "no"] + + +def make_coordinator(num_workers, num_ps): cluster_def = multi_worker_test_base.create_in_process_cluster( num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc") cluster_def["chief"] = [ @@ -52,68 +62,71 @@ def make_client(num_workers, num_ps): ] cluster_resolver = SimpleClusterResolver( ClusterSpec(cluster_def), rpc_layer="grpc") - return client_lib.Client( + return coordinator_lib.ClusterCoordinator( parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)) -class KPLTest(test.TestCase): +class KPLTest(test.TestCase, parameterized.TestCase): @classmethod def setUpClass(cls): super(KPLTest, cls).setUpClass() - cls.client = make_client(num_workers=3, num_ps=2) - - def testTrainAndServe(self): - # These vocabularies usually come from TFT or a Beam pipeline. - feature_vocab = [ - "avenger", "ironman", "batman", "hulk", "spiderman", "kingkong", - "wonder_woman" - ] - label_vocab = ["yes", "no"] - - with self.client.strategy.scope(): - - # Define KPLs under strategy's scope. Right now, if they have look up - # tables, they will be created on the client. Their variables will be - # created on PS. Ideally they should be cached on each worker since they - # will not be changed in a training step. - feature_lookup_layer = string_lookup.StringLookup() - raw_feature_input = keras.layers.Input( - shape=(3,), dtype=dtypes.string, name="feature", ragged=True) - feature_id_input = feature_lookup_layer(raw_feature_input) - - # Model creates variables as well. - feature_ps = keras.Model({"features": raw_feature_input}, - feature_id_input) - - # TODO(yuefengz): adapt may be expensive for large vocab? - feature_lookup_layer.adapt(feature_vocab) + cls.coordinator = make_coordinator(num_workers=3, num_ps=2) + def define_kpls_for_training(self, use_adapt): + # Define KPLs under strategy's scope. Right now, if they have look up + # tables, they will be created on the client. Their variables will be + # created on PS. Ideally they should be cached on each worker since they + # will not be changed in a training step. + if use_adapt: + feature_lookup_layer = string_lookup.StringLookup(num_oov_indices=1) + feature_lookup_layer.adapt(FEATURE_VOCAB) label_lookup_layer = string_lookup.StringLookup( num_oov_indices=0, mask_token=None) - raw_label_input = keras.layers.Input( - shape=(), dtype=dtypes.string, name="label") - label_id_input = label_lookup_layer(raw_label_input) - label_ps = keras.Model({"label": raw_label_input}, label_id_input) + label_lookup_layer.adapt(LABEL_VOCAB) + else: + feature_lookup_layer = string_lookup.StringLookup( + vocabulary=FEATURE_VOCAB, num_oov_indices=1) + label_lookup_layer = string_lookup.StringLookup( + vocabulary=LABEL_VOCAB, num_oov_indices=0, mask_token=None) - label_lookup_layer.adapt(label_vocab) + raw_feature_input = keras.layers.Input( + shape=(3,), dtype=dtypes.string, name="feature", ragged=True) + feature_id_input = feature_lookup_layer(raw_feature_input) - # Only needed for serving. - label_inverse_lookup_layer = string_lookup.StringLookup( - num_oov_indices=1, - mask_token=None, - vocabulary=label_lookup_layer.get_vocabulary(), - invert=True) + # Model creates variables as well. + feature_ps = keras.Model({"features": raw_feature_input}, feature_id_input) + + raw_label_input = keras.layers.Input( + shape=(), dtype=dtypes.string, name="label") + label_id_input = label_lookup_layer(raw_label_input) + label_ps = keras.Model({"label": raw_label_input}, label_id_input) + + return feature_ps, label_ps + + def define_reverse_lookup_layer(self): + # Only needed for serving. + label_inverse_lookup_layer = string_lookup.StringLookup( + num_oov_indices=1, mask_token=None, vocabulary=LABEL_VOCAB, invert=True) + return label_inverse_lookup_layer + + @combinations.generate( + combinations.combine(mode=["eager"], use_adapt=[True, False])) + def testTrainAndServe(self, use_adapt): + + with self.coordinator.strategy.scope(): + + feature_ps, label_ps = self.define_kpls_for_training(use_adapt) def dataset_fn(): def feature_and_label_gen(): while True: - features = random.sample(feature_vocab, 3) + features = random.sample(FEATURE_VOCAB, 3) label = "yes" if "avenger" in features else "no" yield {"features": features, "label": label} - # The dataset will be created on the client? + # The dataset will be created on the coordinator? raw_dataset = dataset_ops.Dataset.from_generator( feature_and_label_gen, output_types={ @@ -131,25 +144,30 @@ class KPLTest(test.TestCase): }, [x["label"]])) return train_dataset - distributed_dataset = self.client.create_per_worker_dataset(dataset_fn) + distributed_dataset = self.coordinator.create_per_worker_dataset( + dataset_fn) + # Create the model. The input needs to be compatible with KPLs. model_input = keras.layers.Input( shape=(3,), dtype=dtypes.int64, name="model_input") + + # input_dim includes a mask token and an oov token. emb_output = keras.layers.Embedding( - input_dim=len(feature_lookup_layer.get_vocabulary()), output_dim=20)( + input_dim=len(FEATURE_VOCAB) + 2, output_dim=20)( model_input) emb_output = math_ops.reduce_mean(emb_output, axis=1) dense_output = keras.layers.Dense( units=1, activation="sigmoid")( emb_output) model = keras.Model({"features": model_input}, dense_output) + optimizer = rmsprop.RMSprop(learning_rate=0.01) accuracy = keras.metrics.Accuracy() @def_function.function def worker_fn(iterator): - def train_step(iterator): + def replica_fn(iterator): batch_data, labels = next(iterator) with backprop.GradientTape() as tape: pred = model(batch_data, training=True) @@ -163,18 +181,18 @@ class KPLTest(test.TestCase): actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64) accuracy.update_state(labels, actual_pred) - self.client._strategy.run(train_step, args=(iterator,)) + self.coordinator._strategy.run(replica_fn, args=(iterator,)) distributed_iterator = iter(distributed_dataset) for _ in range(10): - self.client.schedule(worker_fn, args=(distributed_iterator,)) - self.client.join() + self.coordinator.schedule(worker_fn, args=(distributed_iterator,)) + self.coordinator.join() self.assertGreater(accuracy.result().numpy(), 0.0) # Create a saved model. model.feature_ps = feature_ps model.label_ps = label_ps - model.label_inverse_lookup_layer = label_inverse_lookup_layer + model.label_inverse_lookup_layer = self.define_reverse_lookup_layer() def create_serving_signature(model): diff --git a/tensorflow/python/keras/distribute/worker_training_state.py b/tensorflow/python/keras/distribute/worker_training_state.py index 6385594e0c0..114fd5d9692 100644 --- a/tensorflow/python/keras/distribute/worker_training_state.py +++ b/tensorflow/python/keras/distribute/worker_training_state.py @@ -73,15 +73,17 @@ class WorkerTrainingState(object): # workers need to perform `save()`. # But all workers should restore from the same checkpoint_dir as passed in # read_checkpoint_manager. - self.write_checkpoint_dir = distributed_file_utils.write_dirpath( + self.read_checkpoint_manager = checkpoint_management.CheckpointManager( + checkpoint, + directory=os.path.join(checkpoint_dir, 'chief'), + max_to_keep=1) + write_checkpoint_dir = distributed_file_utils.write_dirpath( checkpoint_dir, self._model.distribute_strategy) - self.write_checkpoint_manager = checkpoint_management.CheckpointManager( - checkpoint, directory=self.write_checkpoint_dir, max_to_keep=1) - if self.write_checkpoint_dir == checkpoint_dir: - self.read_checkpoint_manager = self.write_checkpoint_manager + if self._model.distribute_strategy.extended.should_checkpoint: + self.write_checkpoint_manager = self.read_checkpoint_manager else: - self.read_checkpoint_manager = checkpoint_management.CheckpointManager( - checkpoint, directory=checkpoint_dir, max_to_keep=1) + self.write_checkpoint_manager = checkpoint_management.CheckpointManager( + checkpoint, directory=write_checkpoint_dir, max_to_keep=1) def back_up(self, epoch): """Back up the current state of training into a checkpoint file. @@ -111,13 +113,8 @@ class WorkerTrainingState(object): Delete the backup directories which should not exist after `fit()` successfully finishes. """ - # pylint: disable=protected-access - for pathname in file_io.get_matching_files_v2( - self.write_checkpoint_manager._prefix + '*'): - file_io.delete_recursively_v2(pathname) - for pathname in file_io.get_matching_files_v2( - os.path.join(self.write_checkpoint_manager.directory, 'checkpoint')): - file_io.delete_recursively_v2(pathname) + if self.write_checkpoint_manager is self.read_checkpoint_manager: + file_io.delete_recursively_v2(self.write_checkpoint_manager.directory) def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode): """Maybe load initial epoch from ckpt considering possible worker recovery. diff --git a/tensorflow/python/keras/distribute/worker_training_state_test.py b/tensorflow/python/keras/distribute/worker_training_state_test.py index cb65747c239..67f15b9748a 100644 --- a/tensorflow/python/keras/distribute/worker_training_state_test.py +++ b/tensorflow/python/keras/distribute/worker_training_state_test.py @@ -24,7 +24,6 @@ from absl.testing import parameterized from tensorflow.python.distribute import combinations as ds_combinations from tensorflow.python.distribute import multi_worker_test_base as test_base from tensorflow.python.framework import test_combinations as combinations -from tensorflow.python.framework.errors_impl import NotFoundError from tensorflow.python.keras import callbacks from tensorflow.python.keras.distribute import multi_worker_testing_utils from tensorflow.python.lib.io import file_io @@ -51,13 +50,8 @@ class ModelCheckpointTest(test_base.IndependentWorkerTestBase, filepath=saving_filepath, save_weights_only=save_weights_only) ] self.assertFalse(file_io.file_exists_v2(saving_filepath)) - - try: - model.fit( - x=train_ds, epochs=2, steps_per_epoch=2, callbacks=callbacks_list) - except NotFoundError as e: - if 'Failed to create a NewWriteableFile' in e.message: - self.skipTest('b/138941852, path not found error in Windows py35.') + model.fit( + x=train_ds, epochs=2, steps_per_epoch=2, callbacks=callbacks_list) tf_saved_model_exists = file_io.file_exists_v2(saving_filepath) tf_weights_only_checkpoint_exists = file_io.file_exists_v2( saving_filepath + '.index') diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index ca6803101d5..3e3db32fee3 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -67,9 +67,9 @@ py_library( "//tensorflow/python/keras:optimizers", "//tensorflow/python/keras:regularizers", "//tensorflow/python/keras/distribute", - "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable", - "//tensorflow/python/keras/mixed_precision/experimental:loss_scale_optimizer", - "//tensorflow/python/keras/mixed_precision/experimental:policy", + "//tensorflow/python/keras/mixed_precision:autocast_variable", + "//tensorflow/python/keras/mixed_precision:loss_scale_optimizer", + "//tensorflow/python/keras/mixed_precision:policy", "//tensorflow/python/keras/saving", "//tensorflow/python/keras/utils:engine_utils", "//tensorflow/python/keras/utils:metrics_utils", @@ -152,9 +152,9 @@ py_library( # TODO(keras-team): Fix the cyclar deps between layer and metrics. # "//tensorflow/python/keras:metrics", "//tensorflow/python/keras:regularizers", - "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable", - "//tensorflow/python/keras/mixed_precision/experimental:loss_scale_optimizer", - "//tensorflow/python/keras/mixed_precision/experimental:policy", + "//tensorflow/python/keras/mixed_precision:autocast_variable", + "//tensorflow/python/keras/mixed_precision:loss_scale_optimizer", + "//tensorflow/python/keras/mixed_precision:policy", "//tensorflow/python/keras/saving", "//tensorflow/python/keras/utils:generic_utils", "//tensorflow/python/keras/utils:layer_utils", @@ -340,7 +340,6 @@ tf_py_test( tags = [ "nomac", # TODO(mihaimaruseac): b/127695564 ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -356,7 +355,6 @@ tf_py_test( tags = [ "nomac", # TODO(mihaimaruseac): b/127695564 ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -372,7 +370,6 @@ tf_py_test( tags = [ "nomac", # TODO(mihaimaruseac): b/127695564 ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -409,6 +406,7 @@ tf_py_test( "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", + "//tensorflow/python/keras", "//tensorflow/python/keras:backend", "//tensorflow/python/keras:callbacks", "//tensorflow/python/keras:combinations", @@ -640,7 +638,6 @@ tf_py_test( tags = [ "nomac", # TODO(mihaimaruseac): b/127695564 ], - tfrt_enabled = True, deps = [ ":base_layer", ":engine", @@ -665,6 +662,7 @@ tf_py_test( ":engine", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:composite_tensor", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", @@ -677,6 +675,7 @@ tf_py_test( "//tensorflow/python:summary_ops_v2", "//tensorflow/python:tensor_array_ops", "//tensorflow/python:tensor_spec", + "//tensorflow/python:type_spec", "//tensorflow/python:util", "//tensorflow/python:variables", "//tensorflow/python/eager:context", @@ -688,7 +687,7 @@ tf_py_test( "//tensorflow/python/keras:testing_utils", "//tensorflow/python/keras/layers", "//tensorflow/python/keras/legacy_tf_layers:core", - "//tensorflow/python/keras/mixed_precision/experimental:policy", + "//tensorflow/python/keras/mixed_precision:policy", "//tensorflow/python/keras/optimizer_v2", "//tensorflow/python/keras/utils:tf_utils", "//tensorflow/python/ops/ragged:ragged_tensor", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index bde0a63b212..558d6a71732 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -56,9 +56,9 @@ from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import input_spec from tensorflow.python.keras.engine import keras_tensor from tensorflow.python.keras.engine import node as node_module -from tensorflow.python.keras.mixed_precision.experimental import autocast_variable -from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import autocast_variable +from tensorflow.python.keras.mixed_precision import loss_scale_optimizer +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.saving.saved_model import layer_serialization from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils @@ -124,9 +124,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): Arguments: trainable: Boolean, whether the layer's variables should be trainable. name: String name of the layer. - dtype: The dtype of the layer's computations and weights (default of - `None` means use `tf.keras.backend.floatx` in TensorFlow 2, or the type - of the first input in TensorFlow 1). + dtype: The dtype of the layer's computations and weights. Can also be a + `tf.keras.mixed_precision.Policy`, which allows the computation and weight + dtype to differ. Default of `None` means to use + `tf.keras.mixed_precision.global_policy()`, which is a float32 policy + unless set to different value. dynamic: Set this to `True` if your layer should only be run eagerly, and should not be used to generate a static computation graph. This would be the case for a Tree-RNN or a recursive network, @@ -137,9 +139,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector): Attributes: name: The name of the layer (string). dtype: The dtype of the layer's computations and weights. If mixed - precision is used with a `tf.keras.mixed_precision.experimental.Policy`, - this is instead just the dtype of the layer's weights, as the computations - are done in a different dtype. + precision is used with a `tf.keras.mixed_precision.Policy`, this is + instead just the dtype of the layer's weights, as the computations are + done in a different dtype. trainable_weights: List of variables to be included in backprop. non_trainable_weights: List of variables that should not be included in backprop. @@ -269,18 +271,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): For more information about creating layers, see the guide [Writing custom layers and models with Keras]( https://www.tensorflow.org/guide/keras/custom_layers_and_models) - - About the layer's `dtype` attribute: - - Each layer has a dtype, which is typically the dtype of the layer's - computations and variables. A layer's dtype can be queried via the - `Layer.dtype` property. The dtype is specified with the `dtype` constructor - argument. In TensorFlow 2, the dtype defaults to `tf.keras.backend.floatx()` - if no dtype is passed. `floatx()` itself defaults to "float32". Additionally, - layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed - precision is used, layers may have different computation and variable dtypes. - See `tf.keras.mixed_precision.experimental.Policy` for details on layer - dtypes. """ # See tf.Module for the usage of this property. @@ -304,6 +294,20 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # not available to the restoration code). _must_restore_from_config = False + def _instrument_layer_creation(self): + self._instrumented_keras_api = False + self._instrumented_keras_layer_class = False + self._instrumented_keras_model_class = False + if not getattr(self, '_disable_keras_instrumentation', False): + keras_api_gauge.get_cell('layer').set(True) + self._instrumented_keras_api = True + if getattr(self, '_is_model_for_instrumentation', False): + keras_models_gauge.get_cell(self.__class__.__name__).set(True) + self._instrumented_keras_model_class = True + else: + keras_layers_gauge.get_cell(self.__class__.__name__).set(True) + self._instrumented_keras_layer_class = True + @trackable.no_automatic_dependency_tracking def __init__(self, trainable=True, @@ -311,11 +315,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): dtype=None, dynamic=False, **kwargs): - keras_api_gauge.get_cell('layer').set(True) - if getattr(self, '_is_model_for_instrumentation', False): - keras_models_gauge.get_cell(self.__class__.__name__).set(True) - else: - keras_layers_gauge.get_cell(self.__class__.__name__).set(True) + self._instrument_layer_creation() + # These properties should be set by the user via keyword arguments. # note that 'dtype', 'input_shape' and 'batch_input_shape' # are only applicable to input layers: do not pass these keywords @@ -377,9 +378,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._metrics_lock = threading.Lock() # Both graph and subclassed networks have a dtype policy. For graph - # networks, the policy's compute and variable dtypes are ignored, but other - # fields, like the loss scale, are used by Models. For subclassed networks, - # the compute and variable dtypes are used as like any ordinary layer. + # networks, the policy's compute and variable dtypes are ignored. Such + # networks only use the policy if it is a PolicyV1, in which case it uses + # the PolicyV1's loss_scale (Policy does not have a loss_scale). For + # subclassed networks, the compute and variable dtypes are used as like any + # ordinary layer. self._set_dtype_policy(dtype) # Boolean indicating whether the layer automatically casts its inputs to the # layer's compute_dtype. @@ -596,8 +599,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector): ' for layer %s' % (name, dtype.base_dtype, self.name)) getter = kwargs.pop('getter', base_layer_utils.make_variable) - if (autocast and self._dtype_policy.should_cast_variables and - dtype.is_floating): + if (autocast and + self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype + and dtype.is_floating): old_getter = getter # Wrap variable constructor to return an AutoCastVariable. def getter(*args, **kwargs): # pylint: disable=function-redefined @@ -1232,7 +1236,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector): @property def dtype(self): - """Dtype used by the weights of the layer, set in the constructor.""" + """The dtype of the layer weights. + + This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless + mixed precision is used, this is the same as `Layer.compute_dtype`, the + dtype of the layer's computations. + """ return self._dtype_policy.variable_dtype @property @@ -2348,20 +2357,41 @@ class Layer(module.Module, version_utils.LayerVersionSelector): else: self._compute_dtype_object = None - # TODO(reedwm): Expose this property? @property - def _compute_dtype(self): - """The layer's compute dtype. + def dtype_policy(self): + """The dtype policy associated with this layer. - Unless mixed-precision is used, this is the same as `Layer.dtype`. + This is an instance of a `tf.keras.mixed_precision.Policy`. + """ + return self._dtype_policy - If self._autocast is True, layer's will cast floating-point inputs to this. + @property + def compute_dtype(self): + """The dtype of the layer's computations. + + This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless + mixed precision is used, this is the same as `Layer.dtype`, the dtype of + the weights. + + Layers often perform certain internal computations in higher precision when + `compute_dtype` is float16 or bfloat16 for numeric stability. The output + will still typically be float16 or bfloat16 in such cases. Returns: The layer's compute dtype. """ return self._dtype_policy.compute_dtype + @property + def _compute_dtype(self): + """Deprecated alias of `compute_dtype`.""" + return self._dtype_policy.compute_dtype + + @property + def variable_dtype(self): + """Alias of `Layer.dtype`, the dtype of the weights.""" + return self.dtype + def _maybe_cast_inputs(self, inputs, input_list=None): """Maybe casts the inputs to the compute dtype. @@ -2831,7 +2861,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # Append value to list of trainable / non-trainable weights if relevant # TODO(b/125122625): This won't pick up on any variables added to a # list/dict after creation. - for val in nest.flatten(value): + for val in nest.flatten(value, expand_composites=True): # TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops # no longer return True for isinstance Variable checks. if not isinstance(val, tf_variables.Variable): diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index f5cb63226f3..19ccb92a554 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -27,12 +27,14 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import type_spec from tensorflow.python.keras import backend from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized @@ -83,6 +85,13 @@ class InvalidLayer(base_layer.Layer): class BaseLayerTest(keras_parameterized.TestCase): + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_layer_instrumentation(self): + layer = layers.Add() + self.assertTrue(layer._instrumented_keras_api) + self.assertTrue(layer._instrumented_keras_layer_class) + self.assertFalse(layer._instrumented_keras_model_class) + @combinations.generate(combinations.times( combinations.keras_model_type_combinations(), combinations.keras_tensor_combinations())) @@ -433,6 +442,49 @@ class BaseLayerTest(keras_parameterized.TestCase): # Checks that variables get initialized. model.fit(x, y, batch_size=2, epochs=2) + @combinations.generate(combinations.combine(mode=['eager'])) + def test_composite_variable_assignment(self): + + class Spec(type_spec.TypeSpec): + + value_type = property(lambda self: CompositeVariable) + + def _component_specs(self): + pass + + def _serialize(self): + pass + + def _to_components(self, value): + return value._variables + + def _from_components(self, variable_list): + return CompositeVariable(variable_list) + + class CompositeVariable(composite_tensor.CompositeTensor): + + def __init__(self, variable_list): + self._variables = variable_list + + @property + def _type_spec(self): + return Spec() + + class CompositeVariableLayer(base_layer.Layer): + + def __init__(self): + super().__init__() + self.composite_var = CompositeVariable( + [variables.Variable(1.), + variables.Variable(2.)]) + + layer = CompositeVariableLayer() + self.assertLen(layer.weights, 2) + self.assertIsInstance(layer.weights[0], variables.Variable) + self.assertIsInstance(layer.weights[1], variables.Variable) + self.assertEqual(self.evaluate(layer.weights[0]), 1.) + self.assertEqual(self.evaluate(layer.weights[1]), 2.) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_layer_names(self): with testing_utils.use_keras_tensors_scope(False): @@ -1450,7 +1502,7 @@ class IdentityLayer(base_layer.Layer): class DTypeTest(keras_parameterized.TestCase): # This class only have tests relating to layer.dtype. Tests for dtype policies - # are in mixed_precision/experimental/keras_test.py + # are in mixed_precision/keras_test.py # TODO(reedwm): Maybe have a separate test file for input casting tests. diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index d6b32907593..1acf289c4ca 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -33,10 +33,10 @@ from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.training.tracking import base as tracking +from tensorflow.python.util import keras_deps from tensorflow.python.util import nest from tensorflow.python.util.tf_export import keras_export @@ -417,7 +417,7 @@ def call_context(): return call_ctx -control_flow_util_v2._register_keras_layer_context_function(call_context) # pylint: disable=protected-access +keras_deps.register_call_context_function(call_context) class CallContext(object): @@ -741,9 +741,9 @@ def enable_v2_dtype_behavior(): autocasting part of the V2 behavior for that layer, but not the defaulting to floatx part of the V2 behavior. - When a global `tf.keras.mixed_precision.experimental.Policy` is set, a Keras - layer's dtype will default to the global policy instead of floatx. Layers - will automatically cast inputs to the policy's compute_dtype. + When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's dtype + will default to the global policy instead of floatx. Layers will automatically + cast inputs to the policy's compute_dtype. """ global V2_DTYPE_BEHAVIOR V2_DTYPE_BEHAVIOR = True diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index fd9db0e4346..17c7faf37ff 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -46,9 +46,9 @@ from tensorflow.python.keras import regularizers from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import input_spec -from tensorflow.python.keras.mixed_precision.experimental import autocast_variable -from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import autocast_variable +from tensorflow.python.keras.mixed_precision import loss_scale_optimizer +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.saving.saved_model import layer_serialization from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils @@ -112,9 +112,9 @@ class Layer(base_layer.Layer): Attributes: name: The name of the layer (string). dtype: The dtype of the layer's computations and weights. If mixed - precision is used with a `tf.keras.mixed_precision.experimental.Policy`, - this is instead just the dtype of the layer's weights, as the computations - are done in a different dtype. + precision is used with a `tf.keras.mixed_precision.Policy`, this is + instead just the dtype of the layer's weights, as the computations are + done in a different dtype. updates: List of update ops of this layer. losses: List of losses added by this layer. trainable_weights: List of variables to be included in backprop. @@ -133,8 +133,7 @@ class Layer(base_layer.Layer): if no dtype is passed. `floatx()` itself defaults to "float32". Additionally, layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed precision is used, layers may have different computation and variable dtypes. - See `tf.keras.mixed_precision.experimental.Policy` for details on layer - dtypes. + See `tf.keras.mixed_precision.Policy` for details on layer dtypes. """ # See tf.Module for the usage of this property. @@ -152,8 +151,8 @@ class Layer(base_layer.Layer): @trackable.no_automatic_dependency_tracking def __init__(self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs): - base_layer.keras_api_gauge.get_cell('layer').set(True) - base_layer.keras_layers_gauge.get_cell(self.__class__.__name__).set(True) + self._instrument_layer_creation() + # These properties should be set by the user via keyword arguments. # note that 'dtype', 'input_shape' and 'batch_input_shape' # are only applicable to input layers: do not pass these keywords @@ -199,9 +198,11 @@ class Layer(base_layer.Layer): self._metrics = [] # Both graph and subclassed networks have a dtype policy. For graph - # networks, the policy's compute and variable dtypes are ignored, but other - # fields, like the loss scale, are used by Models. For subclassed networks, - # the compute and variable dtypes are used as like any ordinary layer. + # networks, the policy's compute and variable dtypes are ignored. Such + # networks only use the policy if it is a PolicyV1, in which case it uses + # the PolicyV1's loss_scale (Policy does not have a loss_scale). For + # subclassed networks, the compute and variable dtypes are used as like any + # ordinary layer. self._set_dtype_policy(dtype) # Boolean indicating whether the layer automatically casts its inputs to the # layer's compute_dtype. @@ -420,8 +421,9 @@ class Layer(base_layer.Layer): raise ValueError('An initializer for variable %s of type %s is required' ' for layer %s' % (name, dtype.base_dtype, self.name)) - if (autocast and self._dtype_policy.should_cast_variables and - dtype.is_floating): + if (autocast and + self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype + and dtype.is_floating): # Wrap 'getter' with a version that returns an AutoCastVariable. old_getter = getter def getter(*args, **kwargs): # pylint: disable=function-redefined @@ -2102,9 +2104,10 @@ class Layer(base_layer.Layer): # operations. with tf_utils.maybe_init_scope(self): self.build(input_shapes) - # We must set self.built since user defined build functions are not - # constrained to set self.built. - self.built = True + # We must set also ensure that the layer is marked as built, and the build + # shape is stored since user defined build functions may not be calling + # `super.build()` + Layer.build(self, input_shapes) # Optionally load weight values specified at layer instantiation. if self._initial_weights is not None: diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 2cc6f69403e..6afe1840458 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -31,6 +31,7 @@ import six from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.experimental.ops import distribute_options from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import input_lib from tensorflow.python.eager import context @@ -523,9 +524,11 @@ class CompositeTensorDataAdapter(DataAdapter): flat_inputs += nest.flatten(y) def _is_composite(v): - # Dataset inherits from CompositeTensor but shouldn't be handled here. + # Dataset/iterator inherits from CompositeTensor but should be handled + # by DatasetAdapter and GeneratorAdapter. if (tf_utils.is_extension_type(v) and - not isinstance(v, dataset_ops.DatasetV2)): + not isinstance(v, (dataset_ops.DatasetV2, + iterator_ops.IteratorBase))): return True # Support Scipy sparse tensors if scipy is installed if scipy_sparse is not None and scipy_sparse.issparse(v): diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py index 9ca63ec42f0..59613439bf9 100644 --- a/tensorflow/python/keras/engine/data_adapter_test.py +++ b/tensorflow/python/keras/engine/data_adapter_test.py @@ -953,6 +953,25 @@ class DataHandlerTest(keras_parameterized.TestCase): self.assertEqual(returned_data, [[([0],), ([1],), ([2],)], [([0],), ([1],), ([2],)]]) + def test_iterator(self): + def generator(): + for _ in range(2): + for step in range(3): + yield (ops.convert_to_tensor_v2_with_dispatch([step]),) + + it = iter(dataset_ops.Dataset.from_generator( + generator, output_types=('float32',))) + data_handler = data_adapter.DataHandler(it, epochs=2, steps_per_epoch=3) + returned_data = [] + for _, iterator in data_handler.enumerate_epochs(): + epoch_data = [] + for _ in data_handler.steps(): + epoch_data.append(next(iterator)) + returned_data.append(epoch_data) + returned_data = self.evaluate(returned_data) + self.assertEqual(returned_data, [[([0],), ([1],), ([2],)], + [([0],), ([1],), ([2],)]]) + def test_list_of_scalars(self): data_handler = data_adapter.DataHandler([[0], [1], [2]], epochs=2, diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index 892773fa656..57ae9ee92a7 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -869,6 +869,13 @@ class Functional(training_lib.Model): def _trackable_saved_model_saver(self): return network_serialization.NetworkSavedModelSaver(self) + def _get_save_spec(self, dynamic_batch=True): + if getattr(self, '_has_explicit_input_shape', True): + # Functional models and Sequential models that have an explicit input + # shape should use the batch size set by the input layer. + dynamic_batch = False + return super(Functional, self)._get_save_spec(dynamic_batch) + def _make_node_key(layer_name, node_index): return layer_name + '_ib-' + str(node_index) diff --git a/tensorflow/python/keras/engine/ragged_keras_tensor_test.py b/tensorflow/python/keras/engine/ragged_keras_tensor_test.py index 92abdc82240..fc85fef29bf 100644 --- a/tensorflow/python/keras/engine/ragged_keras_tensor_test.py +++ b/tensorflow/python/keras/engine/ragged_keras_tensor_test.py @@ -20,15 +20,20 @@ from __future__ import print_function from absl.testing import parameterized +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import layers from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import training from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test +from tensorflow.python.util import nest class RaggedKerasTensorTest(keras_parameterized.TestCase): @@ -89,6 +94,278 @@ class RaggedKerasTensorTest(keras_parameterized.TestCase): x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) self.assertAllEqual(model(x), x / x) + @parameterized.parameters( + {'property_name': 'values'}, + {'property_name': 'flat_values'}, + {'property_name': 'row_splits'}, + {'property_name': 'nested_row_splits'}, + ) + def test_instance_property(self, property_name): + inp = layers.Input(shape=[None], ragged=True) + out = getattr(inp, property_name) + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + expected_property = getattr(x, property_name) + self.assertAllEqual(model(x), expected_property) + + # Test that it works with serialization and deserialization as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected_property) + + @parameterized.parameters( + {'name': 'value_rowids'}, + {'name': 'nested_value_rowids'}, + {'name': 'nrows'}, + {'name': 'row_starts'}, + {'name': 'row_limits'}, + {'name': 'row_lengths'}, + {'name': 'nested_row_lengths'}, + {'name': 'bounding_shape'}, + { + 'name': 'with_values', + 'args': [[1, 2, 3, 4, 5, 6]] + }, + { + 'name': 'with_flat_values', + 'kwargs': { + 'new_values': [1, 2, 3, 4, 5, 6] + } + }, + { + 'name': 'with_row_splits_dtype', + 'kwargs': { + 'dtype': dtypes.int32 + } + }, + { + 'name': 'merge_dims', + 'args': [0], + 'kwargs': { + 'inner_axis': 1 + } + }, + {'name': 'to_tensor'}, + {'name': 'to_sparse'}, + ) + def test_instance_method(self, name, args=None, kwargs=None): + if not args: + args = [] + if not kwargs: + kwargs = {} + + inp = layers.Input(shape=[None], ragged=True) + out = getattr(inp, name)(*args, **kwargs) + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + expected_property = getattr(x, name)(*args, **kwargs) + # We expand composites before checking equality because + # assertAllEqual otherwise wouldn't work for SparseTensor outputs + for a, b in zip(nest.flatten(model(x), expand_composites=True), + nest.flatten(expected_property, expand_composites=True)): + self.assertAllEqual(a, b) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + for a, b in zip(nest.flatten(model2(x), expand_composites=True), + nest.flatten(expected_property, expand_composites=True)): + self.assertAllEqual(a, b) + + +class RaggedTensorClassMethodAsLayerTest(keras_parameterized.TestCase): + + def test_from_value_rowids(self): + inp = layers.Input(shape=[None]) + out = ragged_tensor.RaggedTensor.from_value_rowids( + inp, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) + model = training.Model(inp, out) + + x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6]) + expected = ragged_tensor.RaggedTensor.from_value_rowids( + x, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_row_splits(self): + inp = layers.Input(shape=[None]) + out = ragged_tensor.RaggedTensor.from_row_splits( + inp, row_splits=[0, 4, 4, 7, 8, 8]) + model = training.Model(inp, out) + + x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6]) + expected = ragged_tensor.RaggedTensor.from_row_splits( + x, row_splits=[0, 4, 4, 7, 8, 8]) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_row_lengths(self): + inp = layers.Input(shape=[None]) + out = ragged_tensor.RaggedTensor.from_row_lengths( + inp, row_lengths=[4, 0, 3, 1, 0]) + model = training.Model(inp, out) + + x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6]) + expected = ragged_tensor.RaggedTensor.from_row_lengths( + x, row_lengths=[4, 0, 3, 1, 0]) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_row_starts(self): + inp = layers.Input(shape=[None]) + out = ragged_tensor.RaggedTensor.from_row_starts( + inp, row_starts=[0, 4, 4, 7, 8]) + model = training.Model(inp, out) + + x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6]) + expected = ragged_tensor.RaggedTensor.from_row_starts( + x, row_starts=[0, 4, 4, 7, 8]) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_row_limits(self): + row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64) + + inp = layers.Input(shape=[None], dtype=dtypes.string) + out = ragged_tensor.RaggedTensor.from_row_limits( + inp, row_limits, validate=False) + model = training.Model(inp, out) + + x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + expected = ragged_tensor.RaggedTensor.from_row_limits( + x, row_limits, validate=False) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_uniform_row_length(self): + inp = layers.Input(shape=[None]) + out = ragged_tensor.RaggedTensor.from_uniform_row_length(inp, 2, 8) + model = training.Model(inp, out) + + x = constant_op.constant( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) + expected = ragged_tensor.RaggedTensor.from_uniform_row_length(x, 2, 8) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_nested_value_row_ids(self): + nested_value_rowids = [ + constant_op.constant([0, 0, 1, 3, 3], dtypes.int64), + constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) + ] + inp = layers.Input(shape=[None], dtype=dtypes.string) + out = ragged_tensor.RaggedTensor.from_nested_value_rowids( + inp, nested_value_rowids) + model = training.Model(inp, out) + + x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + expected = ragged_tensor.RaggedTensor.from_nested_value_rowids( + x, nested_value_rowids) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_nested_row_splits(self): + nested_row_splits = [ + constant_op.constant([0, 2, 3, 3, 5], dtypes.int64), + constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) + ] + inp = layers.Input(shape=[None], dtype=dtypes.string) + out = ragged_tensor.RaggedTensor.from_nested_row_splits( + inp, nested_row_splits) + model = training.Model(inp, out) + + x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + expected = ragged_tensor.RaggedTensor.from_nested_row_splits( + x, nested_row_splits) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_nested_row_lengths(self): + nested_row_lengths = [ + constant_op.constant([2, 1, 0, 2], dtypes.int64), + constant_op.constant([2, 0, 3, 1, 1], dtypes.int64) + ] + inp = layers.Input(shape=[None], dtype=dtypes.string) + out = ragged_tensor.RaggedTensor.from_nested_row_lengths( + inp, nested_row_lengths) + model = training.Model(inp, out) + + x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + expected = ragged_tensor.RaggedTensor.from_nested_row_lengths( + x, nested_row_lengths) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_tensor(self): + inp = layers.Input(shape=[None], ragged=False) + out = ragged_tensor.RaggedTensor.from_tensor(inp) + model = training.Model(inp, out) + + x = constant_op.constant([[3., 4.], [1., 2.], [3., 5.]]) + expected = ragged_tensor.RaggedTensor.from_tensor(x) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_sparse(self): + inp = layers.Input(shape=[None], sparse=True, dtype=dtypes.string) + out = ragged_tensor.RaggedTensor.from_sparse(inp) + model = training.Model(inp, out) + + indices = [[0, 0], [1, 0], [1, 1], [2, 0]] + values = [b'a', b'b', b'c', b'd'] + shape = [4, 5] + sp_value = sparse_tensor.SparseTensor(indices, values, shape) + + expected = ragged_tensor.RaggedTensor.from_sparse(sp_value) + self.assertAllEqual(model(sp_value), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(sp_value), expected) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index d0053599751..e0d53deebee 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -51,13 +51,15 @@ from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import compile_utils from tensorflow.python.keras.engine import data_adapter from tensorflow.python.keras.engine import training_utils -from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso +from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.saving import hdf5_format from tensorflow.python.keras.saving import save from tensorflow.python.keras.saving.saved_model import json_utils from tensorflow.python.keras.saving.saved_model import model_serialization from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils import version_utils from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite @@ -548,12 +550,25 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): def _get_optimizer(self, optimizer): """Wraps `optimizer` in `LossScaleOptimizer` if necessary.""" + # The deprecated PolicyV1 has a loss_scale, which we use for backwards + # compatibility to match TF 2.3 behavior. The new Policy does not have a + # loss_scale, so we use dynamic loss scaling if the mixed_float16 policy is + # used. + if isinstance(self._dtype_policy, policy.PolicyV1): + loss_scale = self._dtype_policy.loss_scale + elif self._dtype_policy.name == 'mixed_float16': + loss_scale = 'dynamic' + else: + loss_scale = None def _get_single_optimizer(opt): opt = optimizers.get(opt) - if (self._dtype_policy.loss_scale is not None and + if (loss_scale is not None and not isinstance(opt, lso.LossScaleOptimizer)): - opt = lso.LossScaleOptimizer(opt, self._dtype_policy.loss_scale) + if loss_scale == 'dynamic': + opt = lso.LossScaleOptimizer(opt) + else: + opt = lso.LossScaleOptimizerV1(opt, loss_scale) return opt return nest.map_structure(_get_single_optimizer, optimizer) @@ -1099,6 +1114,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): if validation_data and self._should_eval(epoch, validation_freq): # Create data_handler for evaluation and cache it. if getattr(self, '_eval_data_handler', None) is None: + self._fit_frame = tf_inspect.currentframe() self._eval_data_handler = data_adapter.DataHandler( x=val_x, y=val_y, @@ -1134,6 +1150,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): # If eval data_hanlder exists, delete it after all epochs are done. if getattr(self, '_eval_data_handler', None) is not None: del self._eval_data_handler + del self._fit_frame callbacks.on_train_end(logs=training_logs) return self.history @@ -1327,7 +1344,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): _disallow_inside_tf_function('evaluate') with self.distribute_strategy.scope(): - if getattr(self, '_eval_data_handler', None) is not None: + # Use cached evaluation data only when it's called in `Model.fit` + if (getattr(self, '_fit_frame', None) is not None + and tf_inspect.currentframe().f_back is self._fit_frame + and getattr(self, '_eval_data_handler', None) is not None): data_handler = self._eval_data_handler else: # Creates a `tf.data.Dataset` and handles batch and epoch iteration. @@ -1933,31 +1953,14 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): include_optimizer=True, save_format=None, signatures=None, - options=None): + options=None, + save_traces=True): + # pylint: disable=line-too-long """Saves the model to Tensorflow SavedModel or a single HDF5 file. - The savefile includes: - - - The model architecture, allowing to re-instantiate the model. - - The model weights. - - The state of the optimizer, allowing to resume training - exactly where you left off. - - This allows you to save the entirety of the state of a model - in a single file. - - Saved models can be re-instantiated via `keras.models.load_model`. - The model returned by `load_model` is a compiled model ready to be used - (unless the saved model was never compiled in the first place). - - Models built with the Sequential and Functional API can be saved to both the - HDF5 and SavedModel formats. Subclassed models can only be saved with the - SavedModel format. - - Note that the model weights may have different scoped names after being - loaded. Scoped names include the model/layer names, such as - `"dense_1/kernel:0"`. It is recommended that you use the layer properties to - access specific variables, e.g. `model.get_layer("dense_1").kernel`. + Please see `tf.keras.models.save_model` or the + [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/) + for details. Arguments: filepath: String, PathLike, path to SavedModel or H5 file to save the @@ -1971,8 +1974,15 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): signatures: Signatures to save with the SavedModel. Applicable to the 'tf' format only. Please see the `signatures` argument in `tf.saved_model.save` for details. - options: Optional `tf.saved_model.SaveOptions` object that specifies - options for saving to SavedModel. + options: (only applies to SavedModel format) + `tf.saved_model.SaveOptions` object that specifies options for + saving to SavedModel. + save_traces: (only applies to SavedModel format) When enabled, the + SavedModel will store the function traces for each layer. This + can be disabled, so that only the configs of each layer are stored. + Defaults to `True`. Disabling this will decrease serialization time + and reduce file size, but it requires that all custom layers/models + implement a `get_config()` method. Example: @@ -1987,8 +1997,9 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): model = load_model('my_model.h5') ``` """ + # pylint: enable=line-too-long save.save_model(self, filepath, overwrite, include_optimizer, save_format, - signatures, options) + signatures, options, save_traces) def save_weights(self, filepath, @@ -2735,20 +2746,24 @@ def _collective_all_reduce_multi_worker(strategy): # for all strategies def _multi_worker_concat(v, strategy): """Order PerReplica objects for CollectiveAllReduceStrategy and concat.""" - replicas = strategy._gather(v, axis=0) # pylint: disable=protected-access - # v might not have the same shape on different replicas - if isinstance(v, ds_values.PerReplica): - shapes = array_ops.concat([ - array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0) - for single_value in v.values - ], - axis=0) - all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access - else: - # v is a tensor. This may happen when, say, we have 2x1 multi-worker. - all_shapes = strategy._gather( # pylint: disable=protected-access - array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), - axis=0) + replicas = strategy.gather(v, axis=0) # pylint: disable=protected-access + # TODO(b/170435030): We now need to make sure these run after the iterator + # GetNext, so that we don't trigger aborting collective ops in the case of + # EOF. Remove after the issue is fixed. + with ops.control_dependencies([replicas]): + # v might not have the same shape on different replicas + if isinstance(v, ds_values.PerReplica): + shapes = array_ops.concat([ + array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0) + for single_value in v.values + ], + axis=0) + all_shapes = strategy.gather(shapes, axis=0) + else: + # v is a tensor. This may happen when, say, we have 2x1 multi-worker. + all_shapes = strategy.gather( + array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), + axis=0) replicas = array_ops.split( replicas, diff --git a/tensorflow/python/keras/engine/training_eager_v1.py b/tensorflow/python/keras/engine/training_eager_v1.py index 2acd7493cb0..a52b20c5aa0 100644 --- a/tensorflow/python/keras/engine/training_eager_v1.py +++ b/tensorflow/python/keras/engine/training_eager_v1.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.keras import backend from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine import training_utils_v1 -from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer +from tensorflow.python.keras.mixed_precision import loss_scale_optimizer from tensorflow.python.keras.utils import losses_utils from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 6a833560cff..dee1055bbc4 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -68,6 +68,19 @@ except ImportError: class TrainingTest(keras_parameterized.TestCase): + @keras_parameterized.run_all_keras_modes + @keras_parameterized.run_with_all_model_types + def test_model_instrumentation(self): + layers = [ + layers_module.Dense(10, dtype=np.float64), + layers_module.Dense(10, dtype=np.float64) + ] + model = testing_utils.get_model_from_layers(layers, input_shape=(1,)) + + self.assertTrue(model._instrumented_keras_api) + self.assertTrue(model._instrumented_keras_model_class) + self.assertFalse(model._instrumented_keras_layer_class) + @keras_parameterized.run_with_all_model_types @keras_parameterized.run_all_keras_modes def test_fit_training_arg(self): @@ -2445,7 +2458,7 @@ class TestTrainingWithDataTensors(keras_parameterized.TestCase): output_a_np = np.random.random((10, 4)) output_b_np = np.random.random((10, 3)) - input_v = backend.variables_module.Variable(input_a_np, dtype='float32') + input_v = variables_lib.Variable(input_a_np, dtype='float32') self.evaluate(variables_lib.variables_initializer([input_v])) a = input_layer.Input(tensor=input_v) b = input_layer.Input(shape=(3,), name='input_b') @@ -2656,7 +2669,7 @@ class TestTrainingWithDataTensors(keras_parameterized.TestCase): out = model.evaluate(input_a_np, None) # Test model with no external data at all. - input_v = backend.variables_module.Variable(input_a_np, dtype='float32') + input_v = variables_lib.Variable(input_a_np, dtype='float32') self.evaluate(variables_lib.variables_initializer([input_v])) a = input_layer.Input(tensor=input_v) a_2 = layers_module.Dense(4, name='dense_1')(a) @@ -2815,7 +2828,7 @@ class TestTrainingWithDataTensors(keras_parameterized.TestCase): }) # test with custom TF placeholder as target - pl_target_a = backend.array_ops.placeholder('float32', shape=(None, 4)) + pl_target_a = array_ops.placeholder('float32', shape=(None, 4)) model.compile(optimizer='rmsprop', loss='mse', target_tensors={'dense_1': pl_target_a}) model.train_on_batch([input_a_np, input_b_np], @@ -3642,16 +3655,20 @@ class TestAutoUpdates(keras_parameterized.TestCase): class TestFunctionTracing(keras_parameterized.TestCase): + def _seq_model_and_data(self): + model = sequential.Sequential([layers_module.Dense(4, activation='relu')]) + model.compile(loss='mse', optimizer='rmsprop') + x = np.random.random((10, 6)) + y = np.random.random((10, 4)) + return model, x, y + @keras_parameterized.run_all_keras_modes( always_skip_v1=True, always_skip_eager=True) def test_no_tracing_between_epoch(self): if sys.version_info[0] < 3: self.skipTest('self.assertLogs() call is not available in Python 2.') - model = sequential.Sequential([layers_module.Dense(4, activation='relu')]) - model.compile(loss='mse', optimizer='rmsprop') - x = np.random.random((10, 6)) - y = np.random.random((10, 4)) + model, x, y = self._seq_model_and_data() logging.set_verbosity(1) with self.assertLogs(level=1) as logs: @@ -3660,6 +3677,21 @@ class TestFunctionTracing(keras_parameterized.TestCase): new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function' self.assertEqual(sum(new_func_graph in log for log in logs.output), 9) + @keras_parameterized.run_all_keras_modes( + always_skip_v1=True, always_skip_eager=True) + def test_evaluate_no_cached_data(self): + if sys.version_info[0] < 3: + self.skipTest('self.assertLogs() call is not available in Python 2.') + + model, x, y = self._seq_model_and_data() + + new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function' + logging.set_verbosity(1) + with self.assertLogs(level=1) as eval_logs: + for _ in range(6): + model.evaluate(x, y, batch_size=5) + self.assertEqual(sum(new_func_graph in log for log in eval_logs.output), 20) + class TestBuildCustomModel(keras_parameterized.TestCase): diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index 54969bb5e83..7d4eb1325df 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -52,7 +52,8 @@ from tensorflow.python.keras.engine import training_eager_v1 from tensorflow.python.keras.engine import training_generator_v1 from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine import training_utils_v1 -from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer +from tensorflow.python.keras.mixed_precision import loss_scale_optimizer +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.saving.saved_model import model_serialization from tensorflow.python.keras.utils import data_utils @@ -359,9 +360,9 @@ class Model(training_lib.Model): distribution_strategy_context.get_strategy()) if isinstance(self._distribution_strategy, - (parameter_server_strategy.ParameterServerStrategyV1, - parameter_server_strategy.ParameterServerStrategy)): - raise NotImplementedError('ParameterServerStrategy currently only works ' + parameter_server_strategy.ParameterServerStrategyV1): + raise NotImplementedError('`tf.compat.v1.distribute.experimental.Paramet' + 'erServerStrategy` currently only works ' 'with the tf.Estimator API') if not self._experimental_run_tf_function: @@ -1342,7 +1343,14 @@ class Model(training_lib.Model): else: self.optimizer = optimizers.get(optimizer) - if (self._dtype_policy.loss_scale is not None and + if isinstance(self._dtype_policy, policy.PolicyV1): + loss_scale = self._dtype_policy.loss_scale + elif self._dtype_policy.name == 'mixed_float16': + loss_scale = 'dynamic' + else: + loss_scale = None + + if (loss_scale is not None and not isinstance(self.optimizer, loss_scale_optimizer.LossScaleOptimizer)): if isinstance(self.optimizer, list): @@ -1356,18 +1364,11 @@ class Model(training_lib.Model): 'with a loss scale used, but got: %s. Using policy: ' '%s' % (self.optimizer, self._dtype_policy)) - self.optimizer = loss_scale_optimizer.LossScaleOptimizer( - self.optimizer, self._dtype_policy.loss_scale) - if (isinstance(self.optimizer, loss_scale_optimizer.LossScaleOptimizer) and - self._dtype_policy.loss_scale and - self.optimizer.loss_scale != self._dtype_policy.loss_scale): - logging.warning('LossScale of LossScaleOptimizer passed to compile (%s) ' - 'is not the same as the dtype policy\'s loss scale (%s). ' - 'Because the dtype policy has a loss scale, you should ' - 'pass an optimizer that is not wrapped with a ' - 'LossScaleOptimizer,' - % (self.optimizer.loss_scale, - self._dtype_policy.loss_scale)) + if loss_scale == 'dynamic': + self.optimizer = loss_scale_optimizer.LossScaleOptimizer(self.optimizer) + else: + self.optimizer = loss_scale_optimizer.LossScaleOptimizerV1( + self.optimizer, loss_scale) def _prepare_validation_data(self, validation_data, batch_size, validation_steps): diff --git a/tensorflow/python/keras/initializers/initializers_v2.py b/tensorflow/python/keras/initializers/initializers_v2.py index 66e6719f31f..0e4fd66027e 100644 --- a/tensorflow/python/keras/initializers/initializers_v2.py +++ b/tensorflow/python/keras/initializers/initializers_v2.py @@ -34,7 +34,7 @@ class Initializer(object): signature: ```python - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): # returns a tensor of shape `shape` and dtype `dtype` # containing values drawn from a distribution of your choice. ``` @@ -54,7 +54,7 @@ class Initializer(object): self.mean = mean self.stddev = stddev - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): return tf.random.normal( shape, mean=self.mean, stddev=self.stddev, dtype=dtype) @@ -68,12 +68,13 @@ class Initializer(object): works fine. """ - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized as specified by the initializer. Args: shape: Shape of the tensor. dtype: Optional dtype of the tensor. + **kwargs: Additional keyword arguments. """ raise NotImplementedError @@ -124,7 +125,7 @@ class Zeros(init_ops_v2.Zeros, Initializer): >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) """ - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized as specified by the initializer. Args: @@ -133,8 +134,9 @@ class Zeros(init_ops_v2.Zeros, Initializer): supported. If not specified, `tf.keras.backend.floatx()` is used, which default to `float32` unless you configured it otherwise (via `tf.keras.backend.set_floatx(float_dtype)`). + **kwargs: Additional keyword arguments. """ - return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype)) + return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs) @keras_export('keras.initializers.Ones', 'keras.initializers.ones', v1=[]) @@ -154,7 +156,7 @@ class Ones(init_ops_v2.Ones, Initializer): >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) """ - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized as specified by the initializer. Args: @@ -163,8 +165,9 @@ class Ones(init_ops_v2.Ones, Initializer): supported. If not specified, `tf.keras.backend.floatx()` is used, which default to `float32` unless you configured it otherwise (via `tf.keras.backend.set_floatx(float_dtype)`). + **kwargs: Additional keyword arguments. """ - return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype)) + return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs) @keras_export('keras.initializers.Constant', @@ -196,7 +199,7 @@ class Constant(Initializer): def __init__(self, value=0): self.value = value - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized to `self.value`. Args: @@ -205,7 +208,9 @@ class Constant(Initializer): `tf.keras.backend.floatx()` is used, which default to `float32` unless you configured it otherwise (via `tf.keras.backend.set_floatx(float_dtype)`). + **kwargs: Additional keyword arguments. """ + del kwargs return constant_op.constant( self.value, dtype=_get_dtype(dtype), shape=shape) @@ -241,7 +246,7 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer): always produce the same random tensor for a given shape and dtype. """ - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized as specified by the initializer. Args: @@ -251,8 +256,10 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer): `tf.keras.backend.floatx()` is used, which default to `float32` unless you configured it otherwise (via `tf.keras.backend.set_floatx(float_dtype)`). + **kwargs: Additional keyword arguments. """ - return super(RandomUniform, self).__call__(shape, dtype=_get_dtype(dtype)) + return super(RandomUniform, self).__call__( + shape, dtype=_get_dtype(dtype), **kwargs) @keras_export('keras.initializers.RandomNormal', @@ -283,17 +290,19 @@ class RandomNormal(init_ops_v2.RandomNormal, Initializer): always produce the same random tensor for a given shape and dtype. """ - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized to random normal values. Args: shape: Shape of the tensor. dtype: Optional dtype of the tensor. Only floating point types are - supported. If not specified, `tf.keras.backend.floatx()` is used, - which default to `float32` unless you configured it otherwise - (via `tf.keras.backend.set_floatx(float_dtype)`) + supported. If not specified, `tf.keras.backend.floatx()` is used, which + default to `float32` unless you configured it otherwise (via + `tf.keras.backend.set_floatx(float_dtype)`) + **kwargs: Additional keyword arguments. """ - return super(RandomNormal, self).__call__(shape, dtype=_get_dtype(dtype)) + return super(RandomNormal, self).__call__( + shape, dtype=_get_dtype(dtype), **kwargs) @keras_export('keras.initializers.TruncatedNormal', @@ -329,17 +338,19 @@ class TruncatedNormal(init_ops_v2.TruncatedNormal, Initializer): always produce the same random tensor for a given shape and dtype. """ - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized to random normal values (truncated). Args: shape: Shape of the tensor. dtype: Optional dtype of the tensor. Only floating point types are - supported. If not specified, `tf.keras.backend.floatx()` is used, - which default to `float32` unless you configured it otherwise - (via `tf.keras.backend.set_floatx(float_dtype)`) + supported. If not specified, `tf.keras.backend.floatx()` is used, which + default to `float32` unless you configured it otherwise (via + `tf.keras.backend.set_floatx(float_dtype)`) + **kwargs: Additional keyword arguments. """ - return super(TruncatedNormal, self).__call__(shape, dtype=_get_dtype(dtype)) + return super(TruncatedNormal, self).__call__( + shape, dtype=_get_dtype(dtype), **kwargs) @keras_export('keras.initializers.VarianceScaling', @@ -384,17 +395,19 @@ class VarianceScaling(init_ops_v2.VarianceScaling, Initializer): always produce the same random tensor for a given shape and dtype. """ - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized as specified by the initializer. Args: shape: Shape of the tensor. dtype: Optional dtype of the tensor. Only floating point types are - supported. If not specified, `tf.keras.backend.floatx()` is used, - which default to `float32` unless you configured it otherwise - (via `tf.keras.backend.set_floatx(float_dtype)`) + supported. If not specified, `tf.keras.backend.floatx()` is used, which + default to `float32` unless you configured it otherwise (via + `tf.keras.backend.set_floatx(float_dtype)`) + **kwargs: Additional keyword arguments. """ - return super(VarianceScaling, self).__call__(shape, dtype=_get_dtype(dtype)) + return super(VarianceScaling, self).__call__( + shape, dtype=_get_dtype(dtype), **kwargs) @keras_export('keras.initializers.Orthogonal', @@ -436,7 +449,7 @@ class Orthogonal(init_ops_v2.Orthogonal, Initializer): ([pdf](https://arxiv.org/pdf/1312.6120.pdf)) """ - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized to an orthogonal matrix. Args: @@ -445,8 +458,10 @@ class Orthogonal(init_ops_v2.Orthogonal, Initializer): supported. If not specified, `tf.keras.backend.floatx()` is used, which default to `float32` unless you configured it otherwise (via `tf.keras.backend.set_floatx(float_dtype)`) + **kwargs: Additional keyword arguments. """ - return super(Orthogonal, self).__call__(shape, dtype=_get_dtype(dtype)) + return super(Orthogonal, self).__call__( + shape, dtype=_get_dtype(dtype), **kwargs) @keras_export('keras.initializers.Identity', @@ -473,7 +488,7 @@ class Identity(init_ops_v2.Identity, Initializer): gain: Multiplicative factor to apply to the identity matrix. """ - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized to a 2D identity matrix. Args: @@ -482,8 +497,10 @@ class Identity(init_ops_v2.Identity, Initializer): supported. If not specified, `tf.keras.backend.floatx()` is used, which default to `float32` unless you configured it otherwise (via `tf.keras.backend.set_floatx(float_dtype)`) + **kwargs: Additional keyword arguments. """ - return super(Identity, self).__call__(shape, dtype=_get_dtype(dtype)) + return super(Identity, self).__call__( + shape, dtype=_get_dtype(dtype), **kwargs) @keras_export('keras.initializers.GlorotUniform', diff --git a/tensorflow/python/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py index f03de2b436e..47822ef0893 100644 --- a/tensorflow/python/keras/initializers_test.py +++ b/tensorflow/python/keras/initializers_test.py @@ -253,6 +253,34 @@ class KerasInitializersTest(test.TestCase): initializer = initializers.deserialize(external_serialized_json) self.assertEqual(initializer.distribution, 'truncated_normal') + def test_partition(self): + with self.cached_session(): + partition_enabled_initializers = [ + initializers.ZerosV2(), + initializers.OnesV2(), + initializers.RandomUniformV2(), + initializers.RandomNormalV2(), + initializers.TruncatedNormalV2(), + initializers.LecunUniformV2(), + initializers.GlorotUniformV2(), + initializers.HeUniformV2() + ] + for initializer in partition_enabled_initializers: + got = initializer( + shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0)) + self.assertEqual(got.shape, (2, 2)) + + partition_forbidden_initializers = [ + initializers.OrthogonalV2(), + initializers.IdentityV2() + ] + for initializer in partition_forbidden_initializers: + with self.assertRaisesRegex( + ValueError, + "initializer doesn't support partition-related arguments"): + initializer( + shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py b/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py index cc0daa4cf70..24fcbf8fa4b 100644 --- a/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py +++ b/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py @@ -19,6 +19,9 @@ from __future__ import print_function import gc import tensorflow as tf + +from tensorflow.python.platform import test as test_lib + layers = tf.keras.layers optimizers = tf.keras.optimizers @@ -150,6 +153,10 @@ class GradientCheckpointTest(tf.test.TestCase): def test_does_not_raise_oom_exception(self): if not _limit_gpu_memory(): self.skipTest('No virtual GPUs found') + if test_lib.is_built_with_rocm(): + self.skipTest( + 'ROCm MIOpen does not support searching for memory-limited' + 'solvers yet so skip the subtest which would result in OOM.') n_step = 2 losses = _train_with_recompute(n_step) self.assertLen(losses, n_step) diff --git a/tensorflow/python/keras/keras_parameterized.py b/tensorflow/python/keras/keras_parameterized.py index a69452e247f..57c24b80b2f 100644 --- a/tensorflow/python/keras/keras_parameterized.py +++ b/tensorflow/python/keras/keras_parameterized.py @@ -113,7 +113,6 @@ def run_with_all_saved_model_formats( tf.test.main() ``` - Args: test_or_class: test method or class to be annotated. If None, this method returns a decorator that can be applied to a test method or @@ -134,7 +133,7 @@ def run_with_all_saved_model_formats( # Exclude h5 save format if H5py isn't available. if h5py is None: exclude_formats.append(['h5']) - saved_model_formats = ['h5', 'tf'] + saved_model_formats = ['h5', 'tf', 'tf_no_traces'] params = [('_%s' % saved_format, saved_format) for saved_format in saved_model_formats if saved_format not in nest.flatten(exclude_formats)] @@ -150,6 +149,8 @@ def run_with_all_saved_model_formats( _test_h5_saved_model_format(f, self, *args, **kwargs) elif saved_format == 'tf': _test_tf_saved_model_format(f, self, *args, **kwargs) + elif saved_format == 'tf_no_traces': + _test_tf_saved_model_format_no_traces(f, self, *args, **kwargs) else: raise ValueError('Unknown model type: %s' % (saved_format,)) return decorated @@ -167,6 +168,18 @@ def _test_tf_saved_model_format(f, test_or_class, *args, **kwargs): f(test_or_class, *args, **kwargs) +def _test_tf_saved_model_format_no_traces(f, test_or_class, *args, **kwargs): + with testing_utils.saved_model_format_scope('tf', save_traces=False): + f(test_or_class, *args, **kwargs) + + +def run_with_all_weight_formats(test_or_class=None, exclude_formats=None): + """Runs all tests with the supported formats for saving weights.""" + exclude_formats = exclude_formats or [] + exclude_formats.append('tf_no_traces') # Only applies to saving models + return run_with_all_saved_model_formats(test_or_class, exclude_formats) + + # TODO(kaftan): Possibly enable 'subclass_custom_build' when tests begin to pass # it. Or perhaps make 'subclass' always use a custom build method. def run_with_all_model_types( diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD index 4846df4f213..9c1d549da0f 100644 --- a/tensorflow/python/keras/layers/BUILD +++ b/tensorflow/python/keras/layers/BUILD @@ -1,8 +1,8 @@ # Description: # Contains the Keras layers (internal TensorFlow version). -load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") package( # TODO(scottzhu): Remove non-keras deps from TF. @@ -590,7 +590,6 @@ tf_py_test( srcs = ["subclassed_layers_test.py"], python_version = "PY3", shard_count = 3, - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -777,7 +776,10 @@ tf_py_test( srcs = ["recurrent_test.py"], python_version = "PY3", shard_count = 12, - tags = ["no_rocm"], + tags = [ + "no_rocm", + "notsan", # TODO(b/170870794) + ], deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -805,7 +807,6 @@ cuda_py_test( size = "medium", srcs = ["separable_convolutional_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -822,6 +823,7 @@ cuda_py_test( shard_count = 12, tags = [ "no_oss", + "notsan", # TODO(b/170954246) ], deps = [ "//tensorflow/python:client_testlib", @@ -837,10 +839,6 @@ cuda_py_test( srcs = ["gru_v2_test.py"], python_version = "PY3", shard_count = 12, - tags = [ - "no_cuda11", - "no_oss", - ], deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -855,7 +853,6 @@ tf_py_test( size = "small", srcs = ["serialization_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -922,7 +919,6 @@ tf_py_test( tags = [ "notsan", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -939,7 +935,6 @@ tf_py_test( size = "small", srcs = ["layers_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":layers", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py index a6f205676ca..04ae43c1879 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -884,7 +884,6 @@ class ConvLSTM2D(ConvRNN2D): self.activity_regularizer = regularizers.get(activity_regularizer) def call(self, inputs, mask=None, training=None, initial_state=None): - self._maybe_reset_cell_dropout_mask(self.cell) return super(ConvLSTM2D, self).call(inputs, mask=mask, training=training, diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index 212ce42ddaa..6772aba605e 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -54,6 +54,7 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import dispatch @@ -765,9 +766,10 @@ class Lambda(Layer): The main reason to subclass `tf.keras.layers.Layer` instead of using a `Lambda` layer is saving and inspecting a Model. `Lambda` layers - are saved by serializing the Python bytecode, whereas subclassed - Layers can be saved via overriding their `get_config` method. Overriding - `get_config` improves the portability of Models. Models that rely on + are saved by serializing the Python bytecode, which is fundamentally + non-portable. They should only be loaded in the same environment where + they were saved. Subclassed layers can be saved in a more portable way + by overriding their `get_config` method. Models that rely on subclassed Layers are also often easier to visualize and reason about. Examples: @@ -1541,3 +1543,261 @@ for slicing_op in [array_ops._slice_helper, # pylint: disable=protected-access array_ops.boolean_mask, array_ops.boolean_mask_v2]: TFSlicingOpDispatcher(slicing_op).register(slicing_op) + + +class InstanceProperty(Layer): + """Wraps an instance property access (e.g. `x.foo`) in a Keras Layer. + + This layer takes an attribute name `attr_name` in the constructor and, + when called on input tensor `obj` returns `obj.attr_name`. + + KerasTensors specialized for specific extension types use it to + represent instance property accesses on the represented object in the + case where the property needs to be dynamically accessed as opposed to + being statically computed from the typespec, e.g. + + x = keras.Input(..., ragged=True) + out = x.flat_values + """ + + @trackable.no_automatic_dependency_tracking + def __init__(self, attr_name, **kwargs): + self.attr_name = attr_name + + if 'name' not in kwargs: + kwargs['name'] = K.unique_object_name( + 'input.' + self.attr_name, zero_based=True, avoid_observed_names=True) + kwargs['autocast'] = False + + # Do not individually trace op layers in the SavedModel. + self._must_restore_from_config = True + + super(InstanceProperty, self).__init__(**kwargs) + + # Preserve all argument data structures when saving/loading a config + # (e.g., don't unnest lists that contain one element) + self._preserve_input_structure_in_config = True + + def call(self, obj): + return getattr(obj, self.attr_name) + + def get_config(self): + config = { + 'attr_name': self.attr_name + } + base_config = super(InstanceProperty, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + return cls(**config) + + +class InstanceMethod(InstanceProperty): + """Wraps an instance method access (e.g. `x.foo(arg)` in a Keras Layer. + + This layer takes an attribute name `attr_name` in the constructor and, + when called on input tensor `obj` with additional arguments `args` and + `kwargs` returns `obj.attr_name(*args, **kwargs)`. + + KerasTensors specialized for specific extension types use it to + represent dynamic instance method calls on the represented object, e.g. + + x = keras.Input(..., ragged=True) + new_values = keras.Input(...) + out = x.with_values(new_values) + """ + + def call(self, obj, args, kwargs): + method = getattr(obj, self.attr_name) + return method(*args, **kwargs) + + +def _delegate_property(keras_tensor_cls, property_name): # pylint: disable=invalid-name + """Register property on a KerasTensor class. + + Calling this multiple times with the same arguments should be a no-op. + + This method exposes a property on the KerasTensor class that will use an + `InstanceProperty` layer to access the property on the represented + intermediate values in the model. + + Arguments: + keras_tensor_cls: The KerasTensor subclass that should expose the property. + property_name: The name of the property to expose and delegate to the + represented (Composite)Tensor. + """ + # We use a lambda because we can't create a Keras layer at import time + # due to dynamic layer class versioning. + property_access = property(lambda self: InstanceProperty(property_name)(self)) # pylint: disable=unnecessary-lambda + setattr(keras_tensor_cls, property_name, property_access) + + +def _delegate_method(keras_tensor_cls, method_name): # pylint: disable=invalid-name + """Register method on a KerasTensor class. + + Calling this function times with the same arguments should be a no-op. + + This method exposes an instance method on the KerasTensor class that will use + an `InstanceMethod` layer to run the desired method on the represented + intermediate values in the model. + + Arguments: + keras_tensor_cls: The KerasTensor subclass that should expose the property. + method_name: The name of the method to expose and delegate to the + represented (Composite)Tensor. + """ + def delegate(self, *args, **kwargs): + return InstanceMethod(method_name)(self, args, kwargs) + setattr(keras_tensor_cls, method_name, delegate) + +# We do not support the `uniform_row_length` property because it +# returns either `None` or an int tensor, and code that relies on it tends +# to check `is None` directly. Delegating it here would always return a +# `KerasTensor`, regardless of what can be statically inferred. This would +# never equal `None`, breaking code that expects it to be partially-static +# in unpredictable ways. +for ragged_property in [ + 'values', + 'flat_values', + 'row_splits', + 'nested_row_splits' +]: + _delegate_property(keras_tensor.RaggedKerasTensor, ragged_property) + +for ragged_method_name in [ + 'value_rowids', + 'nested_value_rowids', + 'nrows', + 'row_starts', + 'row_limits', + 'row_lengths', + 'nested_row_lengths', + 'bounding_shape', + 'with_values', + 'with_flat_values', + 'with_row_splits_dtype', + 'merge_dims', + 'to_tensor', + 'to_sparse', +]: + _delegate_method(keras_tensor.RaggedKerasTensor, ragged_method_name) + +for sparse_property in [ + 'indices', + 'values', +]: + _delegate_property(keras_tensor.SparseKerasTensor, sparse_property) + +for sparse_method in [ + 'with_values', +]: + _delegate_method(keras_tensor.SparseKerasTensor, sparse_method) + + +class ClassMethod(Layer): + """Wraps a TF API Class's class method in a `Layer` object. + + It is inserted by the Functional API construction whenever users call + a supported TF Class's class method on KerasTensors. + + This is useful in the case where users do something like: + x = keras.Input(...) + y = keras.Input(...) + out = tf.RaggedTensor.from_row_splits(x, y) + """ + + @trackable.no_automatic_dependency_tracking + def __init__(self, cls_ref, method_name, **kwargs): + self.cls_ref = cls_ref + self.method_name = method_name + self.cls_symbol = ( + get_canonical_name_for_symbol(self.cls_ref, + add_prefix_to_v1_names=True) or + get_canonical_name_for_symbol(self.cls_ref, + api_name='keras', + add_prefix_to_v1_names=True)) + if 'name' not in kwargs: + kwargs['name'] = K.unique_object_name( + 'tf.' + self.cls_symbol + '.' + self.method_name, zero_based=True, + avoid_observed_names=True) + kwargs['autocast'] = False + + # Do not individually trace op layers in the SavedModel. + self._must_restore_from_config = True + + super(ClassMethod, self).__init__(**kwargs) + + # Preserve all argument data structures when saving/loading a config + # (e.g., don't unnest lists that contain one element) + self._preserve_input_structure_in_config = True + + self._expects_training_arg = False + self._expects_mask_arg = False + + def call(self, args, kwargs): + return getattr(self.cls_ref, self.method_name)(*args, **kwargs) + + def get_config(self): + if not self.cls_symbol: + raise ValueError('This Keras class method conversion tried to convert ' + 'a method belonging to class %s, a class ' + 'that is not an exposed in the TensorFlow API. ' + 'To ensure cross-version compatibility of Keras models ' + 'that use op layers, only op layers produced from ' + 'exported TF API symbols can be serialized.' + % self.cls_symbol) + config = { + 'cls_symbol': self.cls_symbol, + 'method_name': self.method_name + } + + base_config = super(ClassMethod, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = config.copy() + symbol_name = config.pop('cls_symbol') + cls_ref = get_symbol_from_name(symbol_name) + if not cls_ref: + raise ValueError( + 'TF symbol `tf.%s` could not be found.' % symbol_name) + + config['cls_ref'] = cls_ref + + return cls(**config) + + +class TFClassMethodDispatcher(dispatch.OpDispatcher): + """A class method dispatcher that allows building a functional model with TF class methods.""" + + def __init__(self, cls, method_name): + self.cls = cls + self.method_name = method_name + + def handle(self, args, kwargs): + """Handle the specified operation with the specified arguments.""" + if any( + isinstance(x, keras_tensor.KerasTensor) + for x in nest.flatten([args, kwargs])): + return ClassMethod(self.cls, self.method_name)(args[1:], kwargs) + else: + return self.NOT_SUPPORTED + +for ragged_class_method in [ + 'from_value_rowids', + 'from_row_splits', + 'from_row_lengths', + 'from_row_starts', + 'from_row_limits', + 'from_uniform_row_length', + 'from_nested_value_rowids', + 'from_nested_row_splits', + 'from_nested_row_lengths', + 'from_tensor', + 'from_sparse', +]: + TFClassMethodDispatcher( + ragged_tensor.RaggedTensor, ragged_class_method).register( + getattr(ragged_tensor.RaggedTensor, ragged_class_method)) diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index 3f113bf8cbb..ff346737edc 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -30,7 +30,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.layers import core -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py index 9b4a0622daa..2f73074f4a0 100644 --- a/tensorflow/python/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -194,7 +194,7 @@ class Embedding(Layer): out = embedding_ops.embedding_lookup_v2(self.embeddings.variables, inputs) else: out = embedding_ops.embedding_lookup_v2(self.embeddings, inputs) - if self._dtype_policy.should_cast_variables: + if self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype: # Instead of casting the variable as in most layers, cast the output, as # this is mathematically equivalent but is faster. out = math_ops.cast(out, self._dtype_policy.compute_dtype) diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py index fd468bd15b1..50ea36d1c8a 100644 --- a/tensorflow/python/keras/layers/embeddings_test.py +++ b/tensorflow/python/keras/layers/embeddings_test.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index 2809cbb0108..0737fe11712 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -80,8 +80,8 @@ class BatchNormalizationBase(Layer): inference data*. Arguments: - axis: Integer, the axis that should be normalized (typically the features - axis). For instance, after a `Conv2D` layer with + axis: Integer or a list of integers, the axis that should be normalized + (typically the features axis). For instance, after a `Conv2D` layer with `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. momentum: Momentum for the moving average. epsilon: Small float added to variance to avoid dividing by zero. diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py index 79ecc3c3fe1..a98db36ceea 100644 --- a/tensorflow/python/keras/layers/normalization_test.py +++ b/tensorflow/python/keras/layers/normalization_test.py @@ -31,7 +31,7 @@ from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.layers import normalization from tensorflow.python.keras.layers import normalization_v2 -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import math_ops diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD index bca8898fcca..6b828aea24d 100644 --- a/tensorflow/python/keras/layers/preprocessing/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/BUILD @@ -343,9 +343,6 @@ cuda_py_test( srcs = ["category_crossing_test.py"], python_version = "PY3", shard_count = 4, - tags = [ - "no_windows", # b/149031156 - ], deps = [ ":category_crossing", "//tensorflow/python:array_ops", @@ -590,6 +587,10 @@ tf_py_test( size = "small", srcs = ["normalization_test.py"], python_version = "PY3", + tags = [ + "broken", # b/170974360 + "noasan", # TODO(b/337374867) fails with -fsanitize=null + ], deps = [ ":normalization", ":preprocessing_test_utils", @@ -640,7 +641,6 @@ tf_py_test( name = "table_utils_test", srcs = ["table_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":table_utils", "//tensorflow/python:client_testlib", @@ -657,9 +657,6 @@ tf_py_test( srcs = ["text_vectorization_test.py"], python_version = "PY3", shard_count = 4, - tags = [ - "noasan", #TODO(b/161376526): Enable when bug fix lands. - ], deps = [ ":preprocessing_test_utils", ":text_vectorization", @@ -701,6 +698,7 @@ tf_py_test( name = "reduction_test", srcs = ["reduction_test.py"], python_version = "PY3", + tags = ["notsan"], # TODO(b/170783154) deps = [ ":reduction", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD index 88693c7fa25..0935d86cc5f 100644 --- a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD @@ -18,7 +18,6 @@ tf_py_test( name = "category_encoding_benchmark", srcs = ["category_encoding_benchmark.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -36,7 +35,6 @@ tf_py_test( name = "category_crossing_benchmark", srcs = ["category_crossing_benchmark.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -55,7 +53,6 @@ tf_py_test( name = "hashing_benchmark", srcs = ["hashing_benchmark.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -74,7 +71,6 @@ tf_py_test( name = "index_lookup_adapt_benchmark", srcs = ["index_lookup_adapt_benchmark.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -92,7 +88,6 @@ tf_py_test( name = "normalization_adapt_benchmark", srcs = ["normalization_adapt_benchmark.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -112,7 +107,6 @@ cuda_py_test( name = "image_preproc_benchmark", srcs = ["image_preproc_benchmark.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/layers/preprocessing/integer_lookup.py b/tensorflow/python/keras/layers/preprocessing/integer_lookup.py index d0ffc987e01..c15cce3b050 100644 --- a/tensorflow/python/keras/layers/preprocessing/integer_lookup.py +++ b/tensorflow/python/keras/layers/preprocessing/integer_lookup.py @@ -217,3 +217,9 @@ class IntegerLookup(index_lookup.IndexLookup): base_config["oov_value"] = base_config["oov_token"] del base_config["oov_token"] return base_config + + def set_vocabulary(self, vocab): + if isinstance(vocab, str): + vocab = table_utils.get_vocabulary_from_file(vocab) + vocab = [int(v) for v in vocab] + super().set_vocabulary(vocab) diff --git a/tensorflow/python/keras/layers/preprocessing/integer_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/integer_lookup_test.py index 0b71c6aaecc..4b791fc47aa 100644 --- a/tensorflow/python/keras/layers/preprocessing/integer_lookup_test.py +++ b/tensorflow/python/keras/layers/preprocessing/integer_lookup_test.py @@ -426,6 +426,21 @@ class IntegerLookupVocabularyTest( output_dataset = model.predict(input_array) self.assertAllEqual(expected_output, output_dataset) + def test_int_output_explicit_vocab_from_file_via_setter(self): + vocab_list = [42, 1138, 725, 1729] + vocab_path = self._write_to_temp_file("vocab_file", vocab_list) + + input_array = np.array([[42, 1138, 725, 1729], [1729, 725, 42, 203]]) + expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]] + + input_data = keras.Input(shape=(None,), dtype=dtypes.int64) + layer = get_layer_class()() + layer.set_vocabulary(vocab_path) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + def test_non_unique_vocab_fails(self): vocab_data = [42, 1138, 725, 1729, 1729] with self.assertRaisesRegex(ValueError, ".*repeated term.*1729.*"): diff --git a/tensorflow/python/keras/layers/preprocessing/string_lookup.py b/tensorflow/python/keras/layers/preprocessing/string_lookup.py index c70ac50dd07..679f4fd5a71 100644 --- a/tensorflow/python/keras/layers/preprocessing/string_lookup.py +++ b/tensorflow/python/keras/layers/preprocessing/string_lookup.py @@ -212,3 +212,8 @@ class StringLookup(index_lookup.IndexLookup): # This is required because the MutableHashTable doesn't preserve insertion # order, but we rely on the order of the array to assign indices. return [x.decode(self.encoding) for _, x in sorted(zip(values, keys))] + + def set_vocabulary(self, vocab): + if isinstance(vocab, str): + vocab = table_utils.get_vocabulary_from_file(vocab, self.encoding) + super().set_vocabulary(vocab) diff --git a/tensorflow/python/keras/layers/preprocessing/string_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/string_lookup_test.py index 2b45b59fcf4..0ca10ff574b 100644 --- a/tensorflow/python/keras/layers/preprocessing/string_lookup_test.py +++ b/tensorflow/python/keras/layers/preprocessing/string_lookup_test.py @@ -177,6 +177,22 @@ class StringLookupVocabularyTest(keras_parameterized.TestCase, output_dataset = model.predict(input_array) self.assertAllEqual(expected_output, output_dataset) + def test_int_output_explicit_vocab_from_file_via_setter(self): + vocab_list = ["earth", "wind", "and", "fire"] + vocab_path = self._write_to_temp_file("vocab_file", vocab_list) + + input_array = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "michigan"]]) + expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]] + + input_data = keras.Input(shape=(None,), dtype=dtypes.string) + layer = get_layer_class()() + layer.set_vocabulary(vocab_path) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + def test_non_unique_vocab_fails(self): vocab_data = ["earth", "wind", "and", "fire", "fire"] with self.assertRaisesRegex(ValueError, ".*repeated term.*fire.*"): diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py index 6449d8afaf7..096fd489ded 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py @@ -28,6 +28,7 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras.engine import base_preprocessing_layer from tensorflow.python.keras.layers.preprocessing import category_encoding from tensorflow.python.keras.layers.preprocessing import string_lookup +from tensorflow.python.keras.layers.preprocessing import table_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops @@ -481,7 +482,8 @@ class TextVectorization(base_preprocessing_layer.CombinerPreprocessingLayer): it. Arguments: - vocab: An array of string tokens. + vocab: An array of string tokens, or a path to a file containing one + token per line. df_data: An array of document frequency data. Only necessary if the layer output_mode is TFIDF. oov_df_value: The document frequency of the OOV token. Only necessary if @@ -506,6 +508,21 @@ class TextVectorization(base_preprocessing_layer.CombinerPreprocessingLayer): "be changed after the layer is " "called.").format(mode=self._output_mode)) + # Handle reading from a file. We can't do this via TF-IDF, as we don't have + # a standard format - we error out and ask our users to parse the file + # themselves. + if isinstance(vocab, str): + if self._output_mode == TFIDF: + raise RuntimeError("Setting vocabulary directly from a file is not " + "supported in TF-IDF mode, since this layer cannot " + "read files containing TF-IDF weight data. Please " + "read the file using Python and set the vocab " + "and weights by passing lists or arrays to the " + "set_vocabulary function's `vocab` and `df_data` " + "args.") + vocab = table_utils.get_vocabulary_from_file( + vocab, self._index_lookup_layer.encoding) + self._index_lookup_layer.set_vocabulary(vocab) # When doing raw or integer output, we don't have a Vectorize layer to diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py index a1f9f54a39f..adab52f6dda 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py @@ -44,6 +44,7 @@ from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import gen_string_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_string_ops +from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -414,6 +415,15 @@ class TextVectorizationPreprocessingTest( keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest): + def _write_to_temp_file(self, file_name, vocab_list): + vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt") + with gfile.GFile(vocab_path, "w") as writer: + for vocab in vocab_list: + writer.write(vocab + "\n") + writer.flush() + writer.close() + return vocab_path + def test_summary_before_adapt(self): input_data = keras.Input(shape=(None,), dtype=dtypes.string) layer = get_layer_class()( @@ -709,6 +719,46 @@ class TextVectorizationPreprocessingTest( output_dataset = model.predict(input_array) self.assertAllEqual(expected_output, output_dataset) + def test_vocab_setting_via_init_file(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "michigan"]]) + expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]] + + vocab_path = self._write_to_temp_file("vocab_file", vocab_data) + input_data = keras.Input(shape=(None,), dtype=dtypes.string) + layer = get_layer_class()( + max_tokens=None, + standardize=None, + split=None, + output_mode=text_vectorization.INT, + vocabulary=vocab_path) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + + def test_vocab_setting_via_setter(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "michigan"]]) + expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]] + + vocab_path = self._write_to_temp_file("vocab_file", vocab_data) + input_data = keras.Input(shape=(None,), dtype=dtypes.string) + layer = get_layer_class()( + max_tokens=None, + standardize=None, + split=None, + output_mode=text_vectorization.INT) + layer.set_vocabulary(vocab_path) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + @keras_parameterized.run_all_keras_modes class TextVectorizationDistributionTest( diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 72ba1fbcc58..0748b5bca04 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -1568,7 +1568,6 @@ class SimpleRNN(RNN): self.input_spec = [InputSpec(ndim=3)] def call(self, inputs, mask=None, training=None, initial_state=None): - self._maybe_reset_cell_dropout_mask(self.cell) return super(SimpleRNN, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @@ -2103,7 +2102,6 @@ class GRU(RNN): self.input_spec = [InputSpec(ndim=3)] def call(self, inputs, mask=None, training=None, initial_state=None): - self._maybe_reset_cell_dropout_mask(self.cell) return super(GRU, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @@ -2778,7 +2776,6 @@ class LSTM(RNN): self.input_spec = [InputSpec(ndim=3)] def call(self, inputs, mask=None, training=None, initial_state=None): - self._maybe_reset_cell_dropout_mask(self.cell) return super(LSTM, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @@ -3039,7 +3036,8 @@ def _caching_device(rnn_cell): 'consider updating your code to remove tf.while_loop if ' 'possible.') return None - if rnn_cell._dtype_policy.should_cast_variables: + if (rnn_cell._dtype_policy.compute_dtype != + rnn_cell._dtype_policy.variable_dtype): logging.warn('Variable read device caching has been disabled since it ' 'doesn\'t work with the mixed precision API. This is ' 'likely to cause a slowdown for RNN training due to ' diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py index 8eddc14f5f8..d3c9111cd65 100644 --- a/tensorflow/python/keras/layers/recurrent_v2.py +++ b/tensorflow/python/keras/layers/recurrent_v2.py @@ -67,7 +67,7 @@ _CUDNN_NOT_AVAILABLE_MSG = ('Layer %s will not use cuDNN kernel since it ' def _use_new_code(): - return True + return False # TODO(b/169707691): The wrapper can be removed if TFLite doesn't need to rely diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index 6798e5c8fff..c2accf24e58 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -93,8 +93,8 @@ class TimeDistributed(Wrapper): with `channels_last` data format, across 10 timesteps. The batch input shape is `(32, 10, 128, 128, 3)`. - You can then use `TimeDistributed` to apply a `Conv2D` layer to each of the - 10 timesteps, independently: + You can then use `TimeDistributed` to apply the same `Conv2D` layer to each + of the 10 timesteps, independently: >>> inputs = tf.keras.Input(shape=(10, 128, 128, 3)) >>> conv_2d_layer = tf.keras.layers.Conv2D(64, (3, 3)) @@ -102,6 +102,9 @@ class TimeDistributed(Wrapper): >>> outputs.shape TensorShape([None, 10, 126, 126, 64]) + Because `TimeDistributed` applies the same instance of `Conv2D` to each of the + timestamps, the same set of weights are used at each timestamp. + Arguments: layer: a `tf.keras.layers.Layer` instance. diff --git a/tensorflow/python/keras/legacy_tf_layers/BUILD b/tensorflow/python/keras/legacy_tf_layers/BUILD index 45ccd958db0..a8751302681 100644 --- a/tensorflow/python/keras/legacy_tf_layers/BUILD +++ b/tensorflow/python/keras/legacy_tf_layers/BUILD @@ -34,7 +34,7 @@ py_library( "//tensorflow/python/eager:context", "//tensorflow/python/keras:backend", "//tensorflow/python/keras/engine:base_layer", - "//tensorflow/python/keras/mixed_precision/experimental:policy", + "//tensorflow/python/keras/mixed_precision:policy", "//tensorflow/python/training/tracking:base", ], ) @@ -88,7 +88,6 @@ tf_py_test( srcs = ["base_test.py"], main = "base_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":core", ":layers_base", @@ -119,7 +118,6 @@ tf_py_test( srcs = ["core_test.py"], main = "core_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":core", "//tensorflow/python:array_ops", @@ -171,7 +169,6 @@ tf_py_test( main = "pooling_test.py", python_version = "PY3", tags = ["no_rocm"], - tfrt_enabled = True, deps = [ ":pooling", "//tensorflow/python:array_ops", @@ -189,7 +186,6 @@ cuda_py_test( main = "normalization_test.py", python_version = "PY3", shard_count = 10, - tfrt_enabled = True, deps = [ ":convolutional", ":normalization", diff --git a/tensorflow/python/keras/legacy_tf_layers/base.py b/tensorflow/python/keras/legacy_tf_layers/base.py index 8052651efa7..2a3f477456f 100644 --- a/tensorflow/python/keras/legacy_tf_layers/base.py +++ b/tensorflow/python/keras/legacy_tf_layers/base.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables @@ -210,6 +210,9 @@ class Layer(base_layer.Layer): if 'autocast' not in kwargs: kwargs['autocast'] = False + # Mark that legacy layers should not be instrumented as Keras usage + self._disable_keras_instrumentation = True + super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype, **kwargs) diff --git a/tensorflow/python/keras/legacy_tf_layers/base_test.py b/tensorflow/python/keras/legacy_tf_layers/base_test.py index 2c9810c4109..90d57fae407 100644 --- a/tensorflow/python/keras/legacy_tf_layers/base_test.py +++ b/tensorflow/python/keras/legacy_tf_layers/base_test.py @@ -60,6 +60,9 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase): layer = base_layers.Layer(name='my_layer', trainable=False) self.assertEqual(layer.trainable, False) + # Assert that the layer was not instrumented as a Keras layer + self.assertFalse(layer._instrumented_keras_api) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def testInt64Layer(self): layer = base_layers.Layer(name='my_layer', dtype='int64') @@ -83,6 +86,8 @@ class BaseLayerTest(test.TestCase, parameterized.TestCase): with base_layers.keras_style_scope(): layer = base_layers.Layer(name='my_layer') + # Assert that the layer was not instrumented as a Keras layer + self.assertFalse(layer._instrumented_keras_api) # Test basic variable creation. with backend.name_scope('bar'): variable = layer.add_variable( diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/BUILD similarity index 96% rename from tensorflow/python/keras/mixed_precision/experimental/BUILD rename to tensorflow/python/keras/mixed_precision/BUILD index d1bd18f85a5..b12ce250eef 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/BUILD +++ b/tensorflow/python/keras/mixed_precision/BUILD @@ -69,13 +69,14 @@ py_test( ], python_version = "PY3", srcs_version = "PY2AND3", + tags = ["no_rocm"], deps = [ ":policy", "//tensorflow/python:client_testlib", "//tensorflow/python:platform_test", "//tensorflow/python/keras", "//tensorflow/python/keras:combinations", - "//tensorflow/python/keras/mixed_precision/experimental:loss_scale_optimizer", + "//tensorflow/python/keras/mixed_precision:loss_scale_optimizer", "//tensorflow/python/keras/optimizer_v2", ], ) @@ -256,8 +257,8 @@ cuda_py_test( size = "medium", srcs = ["keras_test.py"], data = [ - "//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_ckpt_tf2.2", - "//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_savedmodel_tf2.2", + "//tensorflow/python/keras/mixed_precision/testdata:lso_ckpt_tf2.2", + "//tensorflow/python/keras/mixed_precision/testdata:lso_savedmodel_tf2.2", ], python_version = "PY3", shard_count = 10, diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py b/tensorflow/python/keras/mixed_precision/autocast_variable.py similarity index 95% rename from tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py rename to tensorflow/python/keras/mixed_precision/autocast_variable.py index b33ea3a0b33..3cacee0cb82 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py +++ b/tensorflow/python/keras/mixed_precision/autocast_variable.py @@ -51,9 +51,6 @@ class AutoCastVariable(variables.Variable, core.Tensor): >>> with enable_auto_cast_variables(tf.float16): ... tf.identity(v).dtype tf.float16 - >>> with enable_auto_cast_variables(tf.float16): - ... v.dtype # v.dtype also changes under the context manager - tf.float16 The purpose of this class is to allow Keras layers to create variables in float32, and automatically cast them to float16 or bfloat16 when the layer is @@ -82,38 +79,42 @@ class AutoCastVariable(variables.Variable, core.Tensor): def _should_cast(self): """Returns True if this variable should be casted when accessed.""" autocast_dtype = getattr(_autocast_dtype, 'dtype', None) - return autocast_dtype is not None and self.true_dtype != autocast_dtype + return autocast_dtype is not None and self.dtype != autocast_dtype @property def dtype(self): - """The dtype this variable will be casted to when read.""" - dtype = getattr(_autocast_dtype, 'dtype', None) - return dtype or self._variable.dtype + """The dtype of the underlying variable, before any casts are done.""" + return self._variable.dtype @property def true_dtype(self): - """The dtype of the underlying variable, before any casts are done.""" + """Deprecated alias of `dtype`.""" return self._variable.dtype + @property + def _cast_dtype(self): + dtype = getattr(_autocast_dtype, 'dtype', None) + return dtype or self._variable.dtype + def value(self): val = self._variable.value() if not self._should_cast(): return val - return math_ops.cast(val, self.dtype) + return math_ops.cast(val, self._cast_dtype) def read_value(self): val = self._variable.read_value() - return math_ops.cast(val, self.dtype) + return math_ops.cast(val, self._cast_dtype) def sparse_read(self, indices, name=None): """Reads the value of this variable sparsely, using `gather`.""" val = self._variable.sparse_read(indices, name=name) - return math_ops.cast(val, self.dtype) + return math_ops.cast(val, self._cast_dtype) def gather_nd(self, indices, name=None): """Gather slices of the variable into a Tensor.""" val = self._variable.gather_nd(indices, name=name) - return math_ops.cast(val, self.dtype) + return math_ops.cast(val, self._cast_dtype) def __getattr__(self, name): return getattr(self._variable, name) @@ -124,13 +125,14 @@ class AutoCastVariable(variables.Variable, core.Tensor): return ops.convert_to_tensor(self._variable, dtype, name, as_ref) # TODO(reedwm): Support as_ref? assert not as_ref - if dtype is not None and not dtype.is_compatible_with(self.dtype): + if dtype is not None and not dtype.is_compatible_with(self._cast_dtype): raise ValueError( - 'Incompatible type conversion requested to type {!r} for variable ' - 'of type {!r}'.format(dtype.name, self.dtype.name)) + 'Incompatible type conversion requested to type {!r} for ' + 'AutoCastVariable which is casted to type {!r}'.format( + dtype.name, self._cast_dtype.name)) val = ops.convert_to_tensor_v2_with_dispatch( self._variable, dtype=self._variable.dtype, name=name) - return math_ops.cast(val, self.dtype) + return math_ops.cast(val, self._cast_dtype) def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" @@ -139,13 +141,13 @@ class AutoCastVariable(variables.Variable, core.Tensor): def __repr__(self): if context.executing_eagerly() and not self._in_graph_mode: repr_str = ("') return repr_str.format( v=self, np_repr=ops.numpy_text(self.read_value(), is_repr=True)) else: repr_str = ("') + 'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>') return repr_str.format(v=self) # Method delegations: We delegate the following methods to self._variable. @@ -504,7 +506,8 @@ def create_autocast_variable(variable, op=None): # pylint: disable=missing-format-attribute return ('' + 'dtype_to_cast_to={v._cast_dtype.name} ' + 'inner_variable={v._variable}>' ).format(v=self) # pylint: enable=missing-format-attribute diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/autocast_variable_test.py similarity index 95% rename from tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py rename to tensorflow/python/keras/mixed_precision/autocast_variable_test.py index 738333039da..c21ff865205 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py +++ b/tensorflow/python/keras/mixed_precision/autocast_variable_test.py @@ -36,7 +36,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import test_combinations as combinations -from tensorflow.python.keras.mixed_precision.experimental import autocast_variable +from tensorflow.python.keras.mixed_precision import autocast_variable from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.ops import array_ops from tensorflow.python.ops import state_ops @@ -77,7 +77,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): # within auto cast scope of different dtype with autocast_variable.enable_auto_cast_variables(dtypes.float16): - self.assertEqual(x.dtype, dtypes.float16) + self.assertEqual(x.dtype, dtypes.float32) self.assertEqual(x.value().dtype, dtypes.float16) self.assertEqual(x.read_value().dtype, dtypes.float16) self.assertEqual(array_ops.identity(x).dtype, dtypes.float16) @@ -111,14 +111,11 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.evaluate(x.initializer) with autocast_variable.enable_auto_cast_variables(dtypes.float16): - self.assertEqual(x.dtype, dtypes.float16) self.assertEqual(x.read_value().dtype, dtypes.float16) with autocast_variable.enable_auto_cast_variables(dtypes.float32): - self.assertEqual(x.dtype, dtypes.float32) self.assertEqual(x.read_value().dtype, dtypes.float32) - self.assertEqual(x.dtype, dtypes.float16) self.assertEqual(x.read_value().dtype, dtypes.float16) @ds_combinations.generate(maybe_distribute) @@ -133,7 +130,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): dtype = dtypes.float16 with autocast_variable.enable_auto_cast_variables(dtype): - self.assertEqual(x.dtype, dtypes.float16) + self.assertEqual(x.dtype, dtypes.float32) self.assertIsInstance(x.dtype, dtypes.DType) self.assertEqual(x.true_dtype, dtypes.float32) self.assertIsInstance(x.true_dtype, dtypes.DType) @@ -153,7 +150,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): def evaluate(var): self.assertIsInstance(var, autocast_variable.AutoCastVariable) - self.assertEqual(var.dtype, read_dtype) + self.assertEqual(array_ops.identity(var).dtype, read_dtype) # pylint: disable=cell-var-from-loop return self.evaluate(var) x = get_var(7., dtypes.float32) @@ -415,13 +412,13 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): self.evaluate(x.initializer) with autocast_variable.enable_auto_cast_variables(dtypes.float16): - self.assertEqual(x.dtype, dtypes.float16) + self.assertEqual(array_ops.identity(x).dtype, dtypes.float16) # New threads should not see the modified value of the autocast dtype. var_dtype = None def f(): nonlocal var_dtype - var_dtype = x.dtype + var_dtype = x._cast_dtype thread = threading.Thread(target=f) thread.start() thread.join() @@ -465,24 +462,26 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): if context.executing_eagerly(): self.assertStartsWith( repr(x), - "" + "" ) with autocast_variable.enable_auto_cast_variables(dtypes.float16): self.assertEqual( repr(x), - "" + "" ) def test_repr_distributed(self): @@ -494,12 +493,14 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): if use_policy: self.assertRegex( repr(x).replace('\n', ' '), - '') else: self.assertRegex( repr(x).replace('\n', ' '), - '') @parameterized.named_parameters( diff --git a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py b/tensorflow/python/keras/mixed_precision/device_compatibility_check.py similarity index 100% rename from tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py rename to tensorflow/python/keras/mixed_precision/device_compatibility_check.py diff --git a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py b/tensorflow/python/keras/mixed_precision/device_compatibility_check_test.py similarity index 98% rename from tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py rename to tensorflow/python/keras/mixed_precision/device_compatibility_check_test.py index ccefa250d2d..381b054fa58 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py +++ b/tensorflow/python/keras/mixed_precision/device_compatibility_check_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import re from tensorflow.python.keras import combinations -from tensorflow.python.keras.mixed_precision.experimental import device_compatibility_check +from tensorflow.python.keras.mixed_precision import device_compatibility_check from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging diff --git a/tensorflow/python/keras/mixed_precision/experimental/__init__.py b/tensorflow/python/keras/mixed_precision/experimental/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tensorflow/python/keras/mixed_precision/experimental/get_layer_policy.py b/tensorflow/python/keras/mixed_precision/get_layer_policy.py similarity index 89% rename from tensorflow/python/keras/mixed_precision/experimental/get_layer_policy.py rename to tensorflow/python/keras/mixed_precision/get_layer_policy.py index 47826b48a97..dec706fde1f 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/get_layer_policy.py +++ b/tensorflow/python/keras/mixed_precision/get_layer_policy.py @@ -29,13 +29,16 @@ from tensorflow.python.util.tf_export import keras_export def get_layer_policy(layer): """Returns the dtype policy of a layer. + Warning: This function is deprecated. Use + `tf.keras.layers.Layer.dtype_policy` instead. + Args: layer: A `tf.keras.layers.Layer`. Returns: - The `tf.keras.mixed_precision.experimental.Policy` of the layer. + The `tf.keras.mixed_precision.Policy` of the layer. """ if not isinstance(layer, base_layer.Layer): raise ValueError('get_policy can only be called on a layer, but got: %s' % (layer,)) - return layer._dtype_policy # pylint: disable=protected-access + return layer.dtype_policy diff --git a/tensorflow/python/keras/mixed_precision/experimental/get_layer_policy_test.py b/tensorflow/python/keras/mixed_precision/get_layer_policy_test.py similarity index 91% rename from tensorflow/python/keras/mixed_precision/experimental/get_layer_policy_test.py rename to tensorflow/python/keras/mixed_precision/get_layer_policy_test.py index f38bdfaf482..ae1ac94055c 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/get_layer_policy_test.py +++ b/tensorflow/python/keras/mixed_precision/get_layer_policy_test.py @@ -20,8 +20,8 @@ from __future__ import print_function from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.layers import core -from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import get_layer_policy +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/keras_test.py similarity index 83% rename from tensorflow/python/keras/mixed_precision/experimental/keras_test.py rename to tensorflow/python/keras/mixed_precision/keras_test.py index dd754e87bb4..d788f0005b0 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py +++ b/tensorflow/python/keras/mixed_precision/keras_test.py @@ -44,10 +44,10 @@ from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import input_spec from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.layers import core -from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy -from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer -from tensorflow.python.keras.mixed_precision.experimental import policy -from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util +from tensorflow.python.keras.mixed_precision import get_layer_policy +from tensorflow.python.keras.mixed_precision import loss_scale_optimizer +from tensorflow.python.keras.mixed_precision import policy +from tensorflow.python.keras.mixed_precision import test_util as mp_test_util from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.keras.saving import save from tensorflow.python.keras.utils import generic_utils @@ -141,7 +141,11 @@ class KerasLayerTest(keras_parameterized.TestCase): y = layer(x) self.assertEqual(layer.v.dtype, dtypes.float32) self.assertEqual(y.dtype, dtype) + self.assertEqual(layer.dtype_policy.name, policy_name) + self.assertIsInstance(layer.dtype_policy, policy.Policy) + self.assertEqual(layer.compute_dtype, dtype) self.assertEqual(layer.dtype, dtypes.float32) + self.assertEqual(layer.variable_dtype, dtypes.float32) self.assertEqual(get_layer_policy.get_layer_policy(layer).name, policy_name) self.evaluate(variables.global_variables_initializer()) @@ -226,7 +230,11 @@ class KerasLayerTest(keras_parameterized.TestCase): # Passing a Policy to dtype overrides the global Policy layer = mp_test_util.MultiplyLayer( assert_type=dtypes.float64, dtype=policy.Policy('float64')) - self.assertEqual(layer.dtype, 'float64') + self.assertEqual(layer.dtype_policy.name, 'float64') + self.assertIsInstance(layer.dtype_policy, policy.Policy) + self.assertEqual(layer.compute_dtype, dtypes.float64) + self.assertEqual(layer.dtype, dtypes.float64) + self.assertEqual(layer.variable_dtype, dtypes.float64) self.assertEqual(layer(x).dtype, dtypes.float64) self.assertEqual(layer.v.dtype, dtypes.float64) @@ -344,32 +352,10 @@ class KerasLayerTest(keras_parameterized.TestCase): self.assertEqual(layer.dtype, 'float32') self.assertEqual(layer(x).dtype, 'float16') self.assertEqual(layer.v.dtype, 'float32') - - layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16', - loss_scale=None)) config = layer.get_config() self.assertEqual(config['dtype'], {'class_name': 'Policy', - 'config': {'name': 'mixed_float16', - 'loss_scale': None}}) - layer = mp_test_util.MultiplyLayer.from_config(config) - self.assertEqual(layer.dtype, 'float32') - self.assertEqual(layer(x).dtype, 'float16') - self.assertEqual(layer.v.dtype, 'float32') - - layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('float64', - loss_scale=2.)) - config = layer.get_config() - self.assertEqual(config['dtype'], - {'class_name': 'Policy', - 'config': {'name': 'float64', - 'loss_scale': { - 'class_name': 'FixedLossScale', - 'config': {'loss_scale_value': 2.0}}}}) - layer = mp_test_util.MultiplyLayer.from_config(config) - self.assertEqual(layer.dtype, 'float64') - self.assertEqual(layer(x).dtype, 'float64') - self.assertEqual(layer.v.dtype, 'float64') + 'config': {'name': 'mixed_float16'}}) layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer')) config = layer.get_config() @@ -383,11 +369,53 @@ class KerasLayerTest(keras_parameterized.TestCase): self.assertEqual(layer(x).dtype, 'float32') self.assertEqual(layer.v.dtype, 'float32') - layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer', - loss_scale=2.)) + @parameterized.named_parameters(*TESTCASES) + def test_config_policy_v1(self, strategy_fn): + x = constant_op.constant([1.], dtype=dtypes.float16) + with strategy_fn().scope(): + + layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('mixed_float16', + loss_scale=None)) config = layer.get_config() + self.assertEqual(config['dtype'], + {'class_name': 'PolicyV1', + 'config': {'name': 'mixed_float16', + 'loss_scale': None}}) + layer = mp_test_util.MultiplyLayer.from_config(config) + self.assertEqual(layer.dtype, 'float32') + self.assertEqual(layer(x).dtype, 'float16') + self.assertEqual(layer.v.dtype, 'float32') + # Restoring a PolicyV1 silently converts it to a Policy and drops the loss + # scale. + self.assertEqual(type(layer.dtype_policy), policy.Policy) + config = layer.get_config() + # The loss_scale is silently dropped self.assertEqual(config['dtype'], {'class_name': 'Policy', + 'config': {'name': 'mixed_float16'}}) + + layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('float64', + loss_scale=2.)) + config = layer.get_config() + self.assertEqual(config['dtype'], + {'class_name': 'PolicyV1', + 'config': {'name': 'float64', + 'loss_scale': { + 'class_name': 'FixedLossScale', + 'config': {'loss_scale_value': 2.0}}}}) + layer = mp_test_util.MultiplyLayer.from_config(config) + self.assertEqual(layer.dtype, 'float64') + self.assertEqual(layer(x).dtype, 'float64') + self.assertEqual(layer.v.dtype, 'float64') + self.assertEqual(type(layer.dtype_policy), policy.Policy) + config = layer.get_config() + self.assertEqual(config['dtype'], 'float64') + + layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('_infer', + loss_scale=2.)) + config = layer.get_config() + self.assertEqual(config['dtype'], + {'class_name': 'PolicyV1', 'config': {'name': '_infer', 'loss_scale': { 'class_name': 'FixedLossScale', @@ -396,6 +424,9 @@ class KerasLayerTest(keras_parameterized.TestCase): self.assertEqual(layer.dtype, None) self.assertEqual(layer(x).dtype, 'float16') self.assertEqual(layer.v.dtype, 'float16') + self.assertEqual(type(layer.dtype_policy), policy.Policy) + config = layer.get_config() + self.assertEqual(config['dtype'], 'float16') def test_delete_variable(self): layer = base_layer.Layer(dtype=policy.Policy('mixed_float16')) @@ -501,6 +532,11 @@ class KerasModelTest(keras_parameterized.TestCase): 'strategy_fn': create_mirrored_strategy, 'save_format': 'h5', 'use_regularizer': True, + }, { + 'testcase_name': 'saved_model_v1_policy', + 'strategy_fn': create_mirrored_strategy, + 'use_v1_policy': True, + 'save_format': 'tf', }) def test_model(self, strategy_fn, @@ -509,19 +545,23 @@ class KerasModelTest(keras_parameterized.TestCase): policy_name='mixed_float16', get_config=False, save_format=None, - use_input_spec=False): + use_input_spec=False, + use_v1_policy=False): self._skip_if_strategy_unsupported(strategy_fn) self._skip_if_save_format_unsupported(save_format) - regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer - else None) + if use_regularizer: + weight_regularizer = mp_test_util.IdentityRegularizer() + activity_regularizer = mp_test_util.ReduceSumRegularizer() + else: + weight_regularizer = activity_regularizer = None with strategy_fn().scope(): - # Pass loss_scale=None, as this test will fail if the DynamicLossScale - # skips applying gradients for a step - with policy.policy_scope(policy.Policy(policy_name, loss_scale=None)): + cls = policy.PolicyV1 if use_v1_policy else policy.Policy + with policy.policy_scope(cls(policy_name)): layer = mp_test_util.MultiplyLayer( assert_type=dtypes.float16, use_operator=use_operator, - regularizer=regularizer, + regularizer=weight_regularizer, + activity_regularizer=activity_regularizer, input_shape=(1,)) if use_input_spec: layer.input_spec = input_spec.InputSpec(shape=(None, 1)) @@ -543,6 +583,10 @@ class KerasModelTest(keras_parameterized.TestCase): # the variable will not change. So this tests the learning rate not # applied to a float16 value, but instead the float32 variable. opt = gradient_descent.SGD(2**-14) + # Use a fixed loss scale, as this test will fail if gradients are + # skipped for a step due to dynamic loss scaling. + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=8) model.compile( opt, loss=loss_fn, @@ -556,8 +600,9 @@ class KerasModelTest(keras_parameterized.TestCase): # from it. expected = 1 - 2**-14 if use_regularizer: - # Regularizer adds another 2 ** -14 to the gradient. - expected -= 2**-14 + # Weight and activity regularizer each add another 2 ** -14 to the + # gradient. + expected -= 2 * 2**-14 self.assertEqual(backend.eval(layer.v), expected) if save_format: @@ -574,14 +619,14 @@ class KerasModelTest(keras_parameterized.TestCase): if 'MultiplyLayer' in layer.__class__.__name__) expected = 1 - 2**-14 if use_regularizer: - expected -= 2**-14 + expected -= 2 * 2**-14 self.assertEqual(backend.eval(layer.v), expected) # Continue training, and assert variable is correct value model.fit(dataset) new_expected = expected - 2 ** -14 if use_regularizer: - new_expected -= 2 ** -14 + new_expected -= 2 * 2 ** -14 self.assertEqual(backend.eval(layer.v), new_expected) # Load saved model again, and assert variable is previous value @@ -597,6 +642,13 @@ class KerasModelTest(keras_parameterized.TestCase): self.assertEqual(layer.v.dtype, 'float32') self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16') + # Loading a model always loads with a v2 Policy, even if saved with a + # PolicyV1. + self.assertEqual(type(model.dtype_policy), policy.Policy) + self.assertEqual(layer.get_config()['dtype'], + {'class_name': 'Policy', 'config': { + 'name': 'mixed_float16'}}) + @keras_parameterized.run_all_keras_modes @parameterized.named_parameters( { @@ -630,7 +682,8 @@ class KerasModelTest(keras_parameterized.TestCase): return math_ops.reduce_mean(y_pred) opt = gradient_descent.SGD(1.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=loss_scale) model.compile( opt, loss=loss_fn, @@ -669,13 +722,10 @@ class KerasModelTest(keras_parameterized.TestCase): strategy = strategy_fn() if use_loss_scaling: loss_scale = 8. - else: - loss_scale = None learning_rate = 2**-14 with strategy.scope(): - with policy.policy_scope(policy.Policy('mixed_float16', - loss_scale=loss_scale)): + with policy.policy_scope(policy.Policy('mixed_float16')): x = layers.Input(shape=(1,), batch_size=2) layer1 = mp_test_util.MultiplyLayer( assert_type=dtypes.float16, @@ -710,6 +760,9 @@ class KerasModelTest(keras_parameterized.TestCase): return math_ops.reduce_mean(y_pred) opt = gradient_descent.SGD(learning_rate) + if use_loss_scaling: + opt = loss_scale_optimizer.LossScaleOptimizer( + opt, dynamic=False, initial_scale=loss_scale) model.compile( opt, loss=loss_fn, @@ -743,6 +796,11 @@ class KerasModelTest(keras_parameterized.TestCase): 'testcase_name': 'get_config', 'strategy_fn': create_mirrored_strategy, 'get_config': True, + }, { + 'testcase_name': 'get_config_v1_lso', + 'strategy_fn': create_mirrored_strategy, + 'get_config': True, + 'use_v1_loss_scale_optimizer': True, }, { 'testcase_name': 'get_config_and_pass_loss_scale_to_policy', 'strategy_fn': create_mirrored_strategy, @@ -752,12 +810,11 @@ class KerasModelTest(keras_parameterized.TestCase): def test_dynamic_loss_scaling(self, strategy_fn, pass_loss_scale_to_policy=False, - get_config=False): + get_config=False, + use_v1_loss_scale_optimizer=False): strategy = strategy_fn() initial_loss_scale = 2. batch_size = 4 - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=initial_loss_scale, increment_period=2) expected_gradient = backend.variable([initial_loss_scale / batch_size], dtype=dtypes.float16) # If this variable is set to True, the model below will have NaN gradients @@ -765,10 +822,19 @@ class KerasModelTest(keras_parameterized.TestCase): with strategy.scope(): opt = gradient_descent.SGD(1.) if pass_loss_scale_to_policy: - p = policy.Policy('mixed_float16', loss_scale=loss_scale) + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=initial_loss_scale, increment_period=2) + p = policy.PolicyV1('mixed_float16', loss_scale=loss_scale) + elif use_v1_loss_scale_optimizer: + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=initial_loss_scale, increment_period=2) + p = policy.Policy('mixed_float16') + opt = loss_scale_optimizer.LossScaleOptimizerV1( + opt, loss_scale) else: - p = policy.Policy('mixed_float16', loss_scale=None) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + p = policy.Policy('mixed_float16') + opt = loss_scale_optimizer.LossScaleOptimizer( + opt, initial_scale=initial_loss_scale, dynamic_growth_steps=2) with policy.policy_scope(p): x = layers.Input( shape=(1,), batch_size=batch_size, dtype=dtypes.float16) @@ -835,19 +901,32 @@ class KerasModelTest(keras_parameterized.TestCase): self.assertEqual(backend.eval(layer.v), -3) @combinations.generate(combinations.combine(mode=['graph', 'eager'])) - def test_loss_scale_optimizer_overrides_policy_loss_scale(self): - with policy.policy_scope(policy.Policy('float32', loss_scale=10.)): + def test_loss_scale_optimizer_overrides_policy_v1_loss_scale(self): + with policy.policy_scope(policy.PolicyV1('float32', loss_scale=10.)): opt = gradient_descent.SGD(1.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=5.) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=5.) x = layers.Input(shape=(1,)) y = mp_test_util.MultiplyLayer()(x) model = models.Model(x, y) model.compile(opt, loss='mse') - self.assertEqual(self.evaluate(model.optimizer.loss_scale()), 5.) + self.assertEqual(self.evaluate(model.optimizer.loss_scale), 5.) + + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_policy_v1_without_loss_scale(self): + with policy.policy_scope(policy.PolicyV1('mixed_float16', + loss_scale=None)): + opt = gradient_descent.SGD(1.) + x = layers.Input(shape=(1,)) + y = mp_test_util.MultiplyLayer()(x) + model = models.Model(x, y) + model.compile(opt, loss='mse') + self.assertNotIsInstance(model.optimizer, + loss_scale_optimizer.LossScaleOptimizer) @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_pass_invalid_optimizer_with_loss_scaling(self): - with policy.policy_scope(policy.Policy('float32', loss_scale=10.)): + with policy.policy_scope(policy.PolicyV1('float32', loss_scale=10.)): x = layers.Input(shape=(1,)) y = mp_test_util.MultiplyLayer()(x) model = models.Model(x, y) @@ -926,7 +1005,7 @@ class KerasModelTest(keras_parameterized.TestCase): def test_save_slot_variables_with_autocast_vars(self, strategy_fn, var_name='v'): - p = policy.Policy('mixed_float16', loss_scale=None) + p = policy.Policy('mixed_float16') with strategy_fn().scope(), policy.policy_scope(p): x = layers.Input(shape=(2,), batch_size=2) # Having a var_name other than 'v' tests that a fixed bug (b/134713714) @@ -938,6 +1017,8 @@ class KerasModelTest(keras_parameterized.TestCase): y = layer(x) model = models.Model(inputs=x, outputs=y) opt = gradient_descent.SGD(1., 1.) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=1) model.compile( optimizer=opt, loss='mse', @@ -971,18 +1052,17 @@ class KerasModelTest(keras_parameterized.TestCase): y = mp_test_util.MultiplyLayer(assert_type=dtypes.float32)(x) model = models.Model(inputs=x, outputs=y) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=1., increment_period=2., multiplier=2.) opt = gradient_descent.SGD(1.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer( + opt, initial_scale=1., dynamic_growth_steps=2.) model.compile( optimizer=opt, loss='mse', run_eagerly=testing_utils.should_run_eagerly()) # Run for 3 steps (6 examples with a batch size of 2) model.fit(np.zeros((6, 2)), np.zeros((6, 2)), batch_size=2) - self.assertEqual(backend.get_value(loss_scale()), 2) - self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1) + self.assertEqual(backend.get_value(opt.loss_scale), 2) + self.assertEqual(backend.get_value(opt.dynamic_counter), 1) # Save model weights. save_prefix = os.path.join(self.get_temp_dir(), 'ckpt') @@ -990,20 +1070,20 @@ class KerasModelTest(keras_parameterized.TestCase): # Run model again for 1 step (2 examples with a batch size of 2) model.fit(np.zeros((2, 2)), np.zeros((2, 2)), batch_size=2) - self.assertEqual(backend.get_value(loss_scale()), 4) - self.assertEqual(backend.get_value(loss_scale._num_good_steps), 0) + self.assertEqual(backend.get_value(opt.loss_scale), 4) + self.assertEqual(backend.get_value(opt.dynamic_counter), 0) # Load model weights and ensure loss scale weights are restored. model.load_weights(save_prefix) - self.assertEqual(backend.get_value(loss_scale()), 2) - self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1) + self.assertEqual(backend.get_value(opt.loss_scale), 2) + self.assertEqual(backend.get_value(opt.dynamic_counter), 1) @keras_parameterized.run_all_keras_modes def test_restore_old_loss_scale_checkpoint(self): # Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format # of LossScaleOptimizer changed, but old checkpoints can still be loaded opt = gradient_descent.SGD(0.1, momentum=0.1) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + opt = loss_scale_optimizer.LossScaleOptimizer(opt) model = sequential.Sequential([core.Dense(2,)]) # The checkpoint and expected values were obtained from the program in @@ -1011,9 +1091,9 @@ class KerasModelTest(keras_parameterized.TestCase): ckpt_dir = os.path.join( flags.FLAGS['test_srcdir'].value, 'org_tensorflow/tensorflow/python/keras', - 'mixed_precision/experimental/testdata/lso_ckpt_tf2.2') + 'mixed_precision/testdata/lso_ckpt_tf2.2') # ckpt_dir = test.test_src_dir_path( - # 'python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2') + # 'python/keras/mixed_precision/testdata/lso_ckpt_tf2.2') model.load_weights(os.path.join(ckpt_dir, 'ckpt')) model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly()) model(np.zeros((2, 2))) # Create model weights @@ -1024,8 +1104,8 @@ class KerasModelTest(keras_parameterized.TestCase): self.assertAllClose( self.evaluate(opt.get_slot(model.weights[0], 'momentum')), expected_slot) - self.assertEqual(self.evaluate(opt.loss_scale()), 32768) - self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1) + self.assertEqual(self.evaluate(opt.loss_scale), 32768) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) # Check restoring works even after the model is compiled and the weights # have been created. @@ -1039,22 +1119,22 @@ class KerasModelTest(keras_parameterized.TestCase): self.assertAllClose( self.evaluate(opt.get_slot(model.weights[0], 'momentum')), expected_slot) - self.assertEqual(self.evaluate(opt.loss_scale()), 32768) - self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1) + self.assertEqual(self.evaluate(opt.loss_scale), 32768) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) def test_restore_old_saved_model(self): saved_model_dir = os.path.join( flags.FLAGS['test_srcdir'].value, 'org_tensorflow/tensorflow/python/keras', - 'mixed_precision/experimental/testdata/lso_savedmodel_tf2.2') + 'mixed_precision/testdata/lso_savedmodel_tf2.2') # saved_model_dir = test.test_src_dir_path( - # 'python/keras/mixed_precision/experimental/testdata/' + # 'python/keras/mixed_precision/testdata/' # 'lso_savedmodel_tf2.2') model = save.load_model(saved_model_dir) expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]]) self.assertAllClose(backend.eval(model.weights[0]), expected_kernel) - self.assertIsInstance(model.optimizer, - loss_scale_optimizer.LossScaleOptimizer) + self.assertEqual(type(model.optimizer), + loss_scale_optimizer.LossScaleOptimizer) @keras_parameterized.run_all_keras_modes @parameterized.named_parameters( @@ -1064,6 +1144,10 @@ class KerasModelTest(keras_parameterized.TestCase): }, { 'testcase_name': 'distribute', 'strategy_fn': create_mirrored_strategy, + }, { + 'testcase_name': 'use_v1_lso', + 'strategy_fn': create_mirrored_strategy, + 'use_v1_loss_scale_optimizer': True }, { 'testcase_name': 'base_h5', 'strategy_fn': default_strategy_fn, @@ -1073,7 +1157,8 @@ class KerasModelTest(keras_parameterized.TestCase): 'strategy_fn': create_mirrored_strategy, 'h5': True, }) - def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False): + def test_save_model_with_dynamic_loss_scaling( + self, strategy_fn, h5=False, use_v1_loss_scale_optimizer=False): # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy # as well. strategy = strategy_fn() @@ -1088,18 +1173,22 @@ class KerasModelTest(keras_parameterized.TestCase): y = mp_test_util.MultiplyLayer()(x) model = models.Model(inputs=x, outputs=y) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=1., increment_period=2., multiplier=2.) opt = gradient_descent.SGD(1.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + if use_v1_loss_scale_optimizer: + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=1., increment_period=2.) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + else: + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=1., + dynamic_growth_steps=2.) model.compile( optimizer=opt, loss='mse', run_eagerly=testing_utils.should_run_eagerly()) # Run for 3 steps (6 examples with a batch size of 2) model.fit(np.ones((6, 2)), np.zeros((6, 2)), batch_size=2) - self.assertEqual(backend.get_value(loss_scale()), 2) - self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1) + self.assertEqual(backend.get_value(opt.loss_scale), 2) + self.assertEqual(backend.get_value(opt.dynamic_counter), 1) (weight,) = model.trainable_weights orig_weight = backend.get_value(weight) @@ -1111,13 +1200,12 @@ class KerasModelTest(keras_parameterized.TestCase): model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2) new_weight = backend.get_value(weight) self.assertNotEqual(new_weight, orig_weight) - self.assertEqual(backend.get_value(loss_scale()), 4) - self.assertEqual(backend.get_value(loss_scale._num_good_steps), 0) + self.assertEqual(backend.get_value(opt.loss_scale), 4) + self.assertEqual(backend.get_value(opt.dynamic_counter), 0) # Load model weights and ensure loss scale weights are restored. model = save.load_model( save_path, custom_objects={'MultiplyLayer': mp_test_util.MultiplyLayer}) - loss_scale = model.optimizer.loss_scale (weight,) = model.trainable_weights loaded_weight = backend.get_value(weight) self.assertEqual(loaded_weight, orig_weight) @@ -1125,8 +1213,14 @@ class KerasModelTest(keras_parameterized.TestCase): # Model.save(). So we assert the loss scale either has the value when it was # saved, or the value it was initialized with. # TODO(reedwm): Always save/restore the loss scale with Model.save(). - self.assertIn(backend.get_value(loss_scale()), (1, 2)) - self.assertIn(backend.get_value(loss_scale._num_good_steps), (0, 1)) + self.assertIn(backend.get_value(model.optimizer.loss_scale), (1, 2)) + self.assertIn(backend.get_value(model.optimizer.dynamic_counter), (0, 1)) + + # Test optimizer attributes and type + self.assertEqual(model.optimizer.initial_scale, 1.) + self.assertEqual(model.optimizer.dynamic_growth_steps, 2.) + self.assertEqual(type(model.optimizer), + loss_scale_optimizer.LossScaleOptimizer) if __name__ == '__main__': diff --git a/tensorflow/python/keras/mixed_precision/experimental/layer_correctness_test.py b/tensorflow/python/keras/mixed_precision/layer_correctness_test.py similarity index 98% rename from tensorflow/python/keras/mixed_precision/experimental/layer_correctness_test.py rename to tensorflow/python/keras/mixed_precision/layer_correctness_test.py index e049b590ddd..bbccc8721cd 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/layer_correctness_test.py +++ b/tensorflow/python/keras/mixed_precision/layer_correctness_test.py @@ -42,7 +42,7 @@ from tensorflow.python.keras.layers import pooling from tensorflow.python.keras.layers import recurrent from tensorflow.python.keras.layers import recurrent_v2 from tensorflow.python.keras.layers import wrappers -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.platform import test @@ -159,6 +159,10 @@ class LayerCorrectnessTest(keras_parameterized.TestCase): input_data: A Numpy array with the data of the input. If None, input data will be randomly generated """ + + if f32_layer_fn == convolutional.ZeroPadding2D and \ + test.is_built_with_rocm(): + return if isinstance(input_shape[0], int): input_shapes = [input_shape] else: diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py b/tensorflow/python/keras/mixed_precision/loss_scale.py similarity index 100% rename from tensorflow/python/keras/mixed_precision/experimental/loss_scale.py rename to tensorflow/python/keras/mixed_precision/loss_scale.py diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_benchmark.py b/tensorflow/python/keras/mixed_precision/loss_scale_benchmark.py similarity index 98% rename from tensorflow/python/keras/mixed_precision/experimental/loss_scale_benchmark.py rename to tensorflow/python/keras/mixed_precision/loss_scale_benchmark.py index 95fcd1168d1..d468326e1ad 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_benchmark.py +++ b/tensorflow/python/keras/mixed_precision/loss_scale_benchmark.py @@ -25,7 +25,7 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import config -from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer +from tensorflow.python.keras.mixed_precision import loss_scale_optimizer from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py similarity index 50% rename from tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py rename to tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py index dd7bf6a682d..f1ca255133e 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py @@ -21,19 +21,26 @@ from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import one_device_strategy +from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import tpu_strategy from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.keras import backend from tensorflow.python.keras import optimizers -from tensorflow.python.keras.mixed_precision.experimental import loss_scale as keras_loss_scale_module +from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import tf_logging +from tensorflow.python.training.experimental import loss_scale as loss_scale_module from tensorflow.python.training.experimental import mixed_precision from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.util import nest from tensorflow.python.util.tf_export import keras_export @@ -173,7 +180,220 @@ class _DelegatingTrackableMixin(object): # pylint: enable=protected-access -@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer') +def _is_all_finite(grads): + """Returns a scalar boolean tensor indicating if all gradients are finite.""" + is_finite_per_grad = [ + math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None + ] + return math_ops.reduce_all(is_finite_per_grad) + + +def _op_in_graph_mode(tensor): + """Returns the tensor's op in graph mode, or the tensor in eager mode. + + This is useful because sometimes an op is needed in graph mode instead of a + tensor. In eager mode, there are no ops. + + Args: + tensor: A tensor. + + Returns: + The tensor's op in graph mode. The tensor in eager mode. + """ + if context.executing_eagerly(): + return tensor + return tensor.op + + +def _assign_if_finite(var, value): + """Assigns a value to a variable if the value is finite.""" + return control_flow_ops.cond( + math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)), + control_flow_ops.no_op) + + +class _DynamicLossScaleState(trackable.Trackable): + """The state of a dynamic loss scale.""" + + def __init__(self, + initial_loss_scale, + growth_steps, + multiplier): + """Creates the dynamic loss scale.""" + super(_DynamicLossScaleState, self).__init__() + self._initial_loss_scale = float(initial_loss_scale) + self._growth_steps = int(growth_steps) + self._multiplier = float(multiplier) + + self._weights = {} + self._current_loss_scale = self._add_weight( + name='current_loss_scale', + dtype=dtypes.float32, + initial_value=self._initial_loss_scale) + # The number of consecutive steps with finite gradients since the last + # nonfinite gradient or change in loss scale. The name is 'good_steps' for + # backwards compatibility with older checkpoints. + self._counter = self._add_weight( + name='good_steps', dtype=dtypes.int64, initial_value=0) + + def _add_weight(self, name, initial_value, dtype=None): + """Adds a weight to this loss scale. + + Args: + name: Variable name. + initial_value: The variable's initial value. + dtype: The type of the variable. + + Returns: + A variable. + + Raises: + RuntimeError: If a weight with `name` has already been added. + """ + variable = variable_scope.variable( + initial_value=initial_value, + name=name, + dtype=dtype, + trainable=False, + use_resource=True, + synchronization=variables.VariableSynchronization.AUTO, + # Set aggregation to NONE, as loss scaling variables should never be + # aggregated. + aggregation=variables.VariableAggregation.NONE) + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + + key = (name, graph_key) + self._weights[key] = variable + self._handle_deferred_dependencies(name=name, trackable=variable) + backend.track_variable(variable) + return variable + + @property + def _checkpoint_dependencies(self): + """From Trackable. Gather graph-specific weights to save.""" + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + weights = [] + for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]): + if g == graph_key: + weights.append(trackable.TrackableReference(name=name, ref=v)) + return (super(_DynamicLossScaleState, self)._checkpoint_dependencies + + weights) + + def _lookup_dependency(self, name): + """From Trackable. Find a weight in the current graph.""" + unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(name) + if unconditional is not None: + return unconditional + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + return self._weights.get((name, graph_key), None) + + @property + def initial_loss_scale(self): + return self._initial_loss_scale + + @property + def growth_steps(self): + return self._growth_steps + + @property + def multiplier(self): + return self._multiplier + + @property + def current_loss_scale(self): + """Returns the current loss scale as a float32 `tf.Variable`.""" + return self._current_loss_scale + + @property + def counter(self): + """Returns the counter as a float32 `tf.Variable`.""" + return self._counter + + def __call__(self): + """Returns the current loss scale as a scalar `float32` tensor.""" + return ops.convert_to_tensor(self._current_loss_scale) + + def update(self, grads): + """Updates the value of the loss scale. + + Args: + grads: A nested structure of unscaled gradients, each which is the + gradient of the loss with respect to a weight. + + Returns: + update_op: In eager mode, None. In graph mode, an op to update the loss + scale. + should_apply_gradients: Either a bool or a scalar boolean tensor. If + False, the caller should skip applying `grads` to the variables this + step. + """ + grads = nest.flatten(grads) + if distribution_strategy_context.has_strategy(): + distribution = distribution_strategy_context.get_strategy() + + def get_is_finite(grads): + is_finite = _is_all_finite(grads) + # We cast to float, because we cannot reduce booleans with + # DistributionStrategy. + return math_ops.cast(is_finite, dtypes.float32) + + is_finite_float = distribution.extended.call_for_each_replica( + get_is_finite, args=(grads,)) + reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM, + is_finite_float, axis=None) + is_finite = math_ops.equal(reduced_is_finite_float, + distribution.num_replicas_in_sync) + else: + is_finite = _is_all_finite(grads) + + def update_if_finite_grads(): + """Update assuming the gradients are finite.""" + + def incr_loss_scale(): + new_loss_scale = self.current_loss_scale * self.multiplier + return control_flow_ops.group( + _assign_if_finite(self.current_loss_scale, new_loss_scale), + self.counter.assign(0)) + + return control_flow_ops.cond( + self.counter + 1 >= self.growth_steps, + incr_loss_scale, + lambda: _op_in_graph_mode(self.counter.assign_add(1))) + + def update_if_not_finite_grads(): + """Update assuming the gradients are nonfinite.""" + + new_loss_scale = math_ops.maximum( + self.current_loss_scale / self.multiplier, 1) + return control_flow_ops.group( + self.counter.assign(0), + self.current_loss_scale.assign(new_loss_scale)) + + update_op = control_flow_ops.cond(is_finite, update_if_finite_grads, + update_if_not_finite_grads) + should_apply_gradients = is_finite + return update_op, should_apply_gradients + + +# See LossScaleOptimizer docstring for why this is so big +_DEFAULT_INITIAL_SCALE = 2 ** 15 +_DEFAULT_GROWTH_STEPS = 2000 + + +# pylint: disable=g-classes-have-attributes +@keras_export('keras.mixed_precision.LossScaleOptimizer') class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): """An optimizer that applies loss scaling. @@ -194,19 +414,21 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): applied. The loss scale can either be a fixed constant, chosen by the user, or be - dynamically determined. Dynamically determining the loss scale is convenient - as a loss scale does not have to be explicitly chosen. However it reduces - performance. + dynamically determined. Using a dynamic loss scale is highly recommend and is + the default behavior, as choosing a specific fixed loss scale is difficult. + Every step, the dynamic loss scale is potentially updated to a new value. + Dynamic loss scaling sometimes causes the loss scale to be too high and cause + the gradients to overflow, in which case gradients are not applied to + variables that step. - This optimizer wraps another optimizer and applies loss scaling to it via a - `LossScale`. Loss scaling is applied whenever gradients are - computed, either through `minimize()` or `get_gradients()`. The loss scale is - updated via `LossScale.update()` whenever gradients are applied, either - through `minimize()` or `apply_gradients()`. For example: + `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it. + Loss scaling is applied whenever gradients are computed, either through + `minimize()` or `get_gradients()`. If dynamic, the loss scale is updated + whenever gradients are applied, either through `minimize()` or + `apply_gradients()`. For example: >>> opt = tf.keras.optimizers.SGD(0.25) - >>> opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, - ... "dynamic") + >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) >>> var = tf.Variable(1.) >>> loss_fn = lambda: var ** 2 >>> # 'minimize' applies loss scaling to the loss and updates the loss sale. @@ -230,12 +452,43 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): >>> var.numpy() 0.25 + Args: + inner_optimizer: The Optimizer instance to wrap. + dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to + True. If True, the loss scale will be dynamically updated over time using + an algorithm that keeps the loss scale at approximately its optimal value. + If False, a single fixed loss scale is used and `initial_scale` must be + specified, which is used as the loss scale. Recommended to keep as True, + as choosing a fixed loss scale can be tricky. Currently, there is a small + performance overhead to dynamic loss scaling compared to fixed loss + scaling. + initial_scale: The initial loss scale. If `dynamic` is True, this defaults + to 2 ** 15. If `dynamic` is False, this must be specified and acts as the + sole loss scale, as the loss scale does not change over time. When dynamic + loss scaling is used, is better for this to be a very high number, because + a loss scale that is too high gets lowered far more quickly than a loss + scale that is too low gets raised. + dynamic_growth_steps: With dynamic loss scaling, every + `dynamic_growth_steps` steps with finite gradients, the loss scale is + doubled. Defaults to 2000. If a nonfinite gradient is encountered, the + count is reset back to zero, gradients are skipped that step, and the loss + scale is halved. The count can be queried with + `LossScaleOptimizer.dynamic_counter`. This argument can only be specified + if `dynamic` is True. + + To use a fixed loss scale instead of dynamic loss scale, pass `dynamic=False` + and pass the loss scale to `initial_scale`. For example: + + >>> opt = tf.keras.mixed_precision.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=1024) + >>> opt.loss_scale.numpy() + 1024. + Hyperparameters can be accessed and set on the LossScaleOptimizer, which will be delegated to the wrapped optimizer. >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5) - >>> lso = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, - ... "dynamic") + >>> lso = tf.keras.mixed_precision.LossScaleOptimizer(opt) >>> opt.beta_1 0.8 >>> lso.beta_1 # Equivalent to `opt.beta_1` @@ -268,49 +521,103 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): _HAS_AGGREGATE_GRAD = True - def __init__(self, optimizer, loss_scale): - """Initializes this loss scale optimizer. - - Args: - optimizer: The Optimizer instance to wrap. - loss_scale: The loss scale to scale the loss and gradients. This can - either be an int/float to use a fixed loss scale, the string "dynamic" - to use dynamic loss scaling, or an instance of a LossScale. The string - "dynamic" equivalent to passing `DynamicLossScale()`, and passing an - int/float is equivalent to passing a FixedLossScale with the given loss - scale. - """ - if not isinstance(optimizer, optimizer_v2.OptimizerV2): - raise ValueError('"optimizer" must be an instance of OptimizerV2, but ' - 'got: %s' % optimizer) + def __init__(self, inner_optimizer, dynamic=True, initial_scale=None, + dynamic_growth_steps=None): + if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2): + raise TypeError('"inner_optimizer" must be an instance of OptimizerV2, ' + 'but got: %s' % inner_optimizer) + if not isinstance(dynamic, bool): + # Catch errors if a user incorrectly passes a string or float to the + # second argument argument, as this is commonly done for + # LossScaleOptimizerV1. + raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must ' + 'be a bool, but got: %r' % (dynamic,)) self._raise_if_strategy_unsupported() - - self._optimizer = optimizer - self._loss_scale = keras_loss_scale_module.get(loss_scale) - if self._loss_scale is None: - raise ValueError('loss_scale cannot be None.') + self._optimizer = inner_optimizer # We don't call super().__init__, since we do not want to call OptimizerV2's # constructor. _DelegatingTrackableMixin.__init__(self, self._optimizer) - for weight in self._loss_scale._weights.values(): # pylint: disable=protected-access - # We cannot call `track_variable` in the LossScale class itself, because a - # file outside of Keras cannot depend on a Keras file. Calling it here - # instead is OK, because a variable only needs to be tracked if used with - # a Keras class, and the only way to use LossScale with a Keras class is - # through the LossScaleOptimizer. - backend.track_variable(weight) - self._track_trackable(self._loss_scale, 'loss_scale') + if dynamic: + if initial_scale is None: + initial_scale = _DEFAULT_INITIAL_SCALE + if dynamic_growth_steps is None: + dynamic_growth_steps = _DEFAULT_GROWTH_STEPS + self._loss_scale = _DynamicLossScaleState( + initial_scale, dynamic_growth_steps, multiplier=2) + self._track_trackable(self._loss_scale, 'loss_scale') + else: + if initial_scale is None: + raise ValueError('"initial_scale" must be specified if "dynamic" is ' + 'False') + self._loss_scale = float(initial_scale) + if dynamic_growth_steps is not None: + raise ValueError('"dynamic_growth_steps" must be None if "dynamic" ' + 'is False, but got: %s' % (dynamic_growth_steps,)) # To support restoring TensorFlow 2.2 checkpoints. self._track_trackable(FakeOptimizerForRestoration(self._optimizer), 'base_optimizer') + @property + def dynamic(self): + return isinstance(self._loss_scale, _DynamicLossScaleState) + @property def loss_scale(self): - """The `LossScale` instance associated with this optimizer.""" - return self._loss_scale + """The current loss scale as a float32 scalar tensor.""" + if isinstance(self._loss_scale, _DynamicLossScaleState): + return ops.convert_to_tensor(self._loss_scale.current_loss_scale) + else: + return ops.convert_to_tensor(self._loss_scale) + + @property + def dynamic_counter(self): + """The number of steps since the loss scale was last increased or decreased. + + This is None if `LossScaleOptimizer.dynamic` is False. + + The counter is incremented every step. Once it reaches + `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled + and the counter will be reset back to zero. If nonfinite gradients are + encountered, the loss scale will be halved and the counter will be reset + back to zero. + """ + if isinstance(self._loss_scale, _DynamicLossScaleState): + return self._loss_scale.counter + else: + return None + + @property + def initial_scale(self): + """The initial loss scale. + + This is None if `LossScaleOptimizer.dynamic` is False. + """ + if isinstance(self._loss_scale, _DynamicLossScaleState): + return self._loss_scale.initial_loss_scale + else: + return self._loss_scale + + @property + def dynamic_growth_steps(self): + """The number of steps it takes to increase the loss scale. + + This is None if `LossScaleOptimizer.dynamic` is False. + + Every `dynamic_growth_steps` consecutive steps with finite gradients, the + loss scale is increased. + """ + if isinstance(self._loss_scale, _DynamicLossScaleState): + return self._loss_scale.growth_steps + else: + return None + + @property + def inner_optimizer(self): + """The optimizer that this LossScaleOptimizer is wrapping.""" + return self._optimizer def get_scaled_loss(self, loss): """Scales the loss by the loss scale. @@ -322,7 +629,7 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): scaling is automatically applied and this method is unneeded. If this method is called, `get_unscaled_gradients` should also be called. - See the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for + See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an example. Args: @@ -330,16 +637,15 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): a tensor or a callable returning a tensor. Returns: - `loss` multiplied by `LossScaleOptimizer.loss_scale()`. + `loss` multiplied by `LossScaleOptimizer.loss_scale`. """ - loss_scale = self._loss_scale() if callable(loss): def new_loss(): loss_val = loss() - return loss_val * math_ops.cast(loss_scale, loss_val.dtype) + return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype) return new_loss else: - return loss * math_ops.cast(loss_scale, loss.dtype) + return loss * math_ops.cast(self.loss_scale, loss.dtype) def get_unscaled_gradients(self, grads): """Unscales the gradients by the loss scale. @@ -351,7 +657,7 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): scaling is automatically applied and this method is unneeded. If this method is called, `get_scaled_loss` should also be called. See - the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for an + the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an example. Args: @@ -360,10 +666,9 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): Returns: A new list the same size as `grads`, where every non-None value in `grads` - is divided by `LossScaleOptimizer.loss_scale()`. + is divided by `LossScaleOptimizer.loss_scale`. """ - loss_scale = self._loss_scale() - loss_scale_reciprocal = 1. / loss_scale + loss_scale_reciprocal = 1. / self.loss_scale return [ _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None for g in grads @@ -379,9 +684,9 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): grad_loss, tape=tape) grads = [g for g, _ in grads_and_vars] - variables = [v for _, v in grads_and_vars] + weights = [v for _, v in grads_and_vars] unscaled_grads = self.get_unscaled_gradients(grads) - return list(zip(unscaled_grads, variables)) + return list(zip(unscaled_grads, weights)) def get_gradients(self, loss, params): loss = self.get_scaled_loss(loss) @@ -409,7 +714,11 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name, experimental_aggregate_gradients): grads = [g for g, _ in grads_and_vars] - loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads) + if isinstance(self._loss_scale, _DynamicLossScaleState): + loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads) + else: + loss_scale_update_op = control_flow_ops.no_op() + should_apply_grads = True def apply_fn(): # We do not want DistributionStrategy to unwrap any MirroredVariables in @@ -447,19 +756,41 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): def get_config(self): serialized_optimizer = optimizers.serialize(self._optimizer) - serialized_loss_scale = keras_loss_scale_module.serialize(self._loss_scale) return { - 'optimizer': serialized_optimizer, - 'loss_scale': serialized_loss_scale, + 'inner_optimizer': serialized_optimizer, + 'dynamic': self.dynamic, + 'initial_scale': self.initial_scale, + 'dynamic_growth_steps': self.dynamic_growth_steps, } @classmethod def from_config(cls, config, custom_objects=None): config = config.copy() # Make a copy, since we mutate config - config['optimizer'] = optimizers.deserialize( - config['optimizer'], custom_objects=custom_objects) - config['loss_scale'] = keras_loss_scale_module.deserialize( - config['loss_scale'], custom_objects=custom_objects) + if 'loss_scale' in config: + # If loss_scale is in config, we assume we are deserializing a + # LossScaleOptimizer from TF 2.3 or below. We convert the config so it + # can be deserialized in the current LossScaleOptimizer. + loss_scale = keras_loss_scale_module.deserialize( + config.pop('loss_scale')) + if isinstance(loss_scale, loss_scale_module.FixedLossScale): + config['dynamic'] = False + config['initial_scale'] = loss_scale._loss_scale_value # pylint: disable=protected-access + elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): + config['dynamic'] = True + config['initial_scale'] = loss_scale.initial_loss_scale + config['dynamic_growth_steps'] = loss_scale.increment_period + if loss_scale.multiplier != 2: + raise ValueError('Cannot deserialize LossScaleOptimizer with a ' + 'DynamicLossScale whose multiplier is not 2. Got ' + 'DynamicLossScale: %s' % (loss_scale,)) + else: + raise ValueError( + 'Serialized LossScaleOptimizers with a LossScale that is neither a ' + 'FixedLossScale nor a DynamicLossScale can no longer be ' + 'deserialized') + config['inner_optimizer'] = config.pop('optimizer') + config['inner_optimizer'] = optimizers.deserialize( + config['inner_optimizer'], custom_objects=custom_objects) return cls(**config) def _raise_if_strategy_unsupported(self): @@ -601,15 +932,164 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): # both self._compute_gradients() and self.apply_gradients(), and both need # to have the LossScaleOptimizer version called. - # TODO(reedwm): Maybe merge this class's functionality into OptimizerV2. - # TODO(reedwm): Maybe throw an error if mixed precision is used without this # optimizer being used. - # Trackable delegations: Delegate all Trackable methods to the wrapped - # optimizer. This is so the checkpoint format for a LossScaleOptimizer is - # identical to the checkpoint format for a normal optimizer, except the loss - # scale is stored in the checkpoint. + +@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer') +class LossScaleOptimizerV1(LossScaleOptimizer): + """An deprecated optimizer that applies loss scaling. + + Warning: This class is deprecated and will be removed in TensorFlow 2.5. + Please use the non-experimental class + `tf.keras.mixed_precision.LossScaleOptimizer` instead. + + This class is identical to the non-experimental + `keras.mixed_precision.LossScaleOptimizer` except its constructor takes + different arguments. For this class (the experimental version), the + constructor takes a `loss_scale` argument. For the non-experimental class, + the constructor encodes the loss scaling information in multiple arguments. + Note that unlike this class, the non-experimental class does not accept a + `tf.compat.v1.mixed_precision.LossScale`, which is deprecated. + + If you currently use this class, you should switch to the non-experimental + `tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several + examples of converting the use of the experimental class to the equivalent + non-experimental class. + + >>> # In all of the the examples below, `opt1` and `opt2` are identical + >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), loss_scale='dynamic') + >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( + ... tf.keras.optimizers.SGD()) + >>> assert opt1.get_config() == opt2.get_config() + + >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), loss_scale=123) + >>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123 + >>> # refers to the initial loss scale, which is the single fixed loss scale + >>> # when dynamic=False. + >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123) + >>> assert opt1.get_config() == opt2.get_config() + + >>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale( + ... initial_loss_scale=2048, increment_period=500) + >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), loss_scale=loss_scale) + >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), initial_scale=2048, + ... dynamic_growth_steps=500) + >>> assert opt1.get_config() == opt2.get_config() + + Args: + optimizer: The Optimizer instance to wrap. + loss_scale: The loss scale to scale the loss and gradients. This can + either be an int/float to use a fixed loss scale, the string "dynamic" + to use dynamic loss scaling, or an instance of a LossScale. The string + "dynamic" equivalent to passing `DynamicLossScale()`, and passing an + int/float is equivalent to passing a FixedLossScale with the given loss + scale. If a DynamicLossScale is passed, DynamicLossScale.multiplier must + be 2 (the default). + """ + + def __init__(self, optimizer, loss_scale): + warn_msg_prefix = ( + 'tf.keras.mixed_precision.experimental.LossScaleOptimizer is ' + 'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer ' + 'instead. ') + + if isinstance(loss_scale, dict): + loss_scale = keras_loss_scale_module.deserialize(loss_scale) + + if isinstance(loss_scale, (int, float)): + tf_logging.warn( + warn_msg_prefix + 'For example\n' + ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(' + 'opt, dynamic=False, initial_scale={})'.format(loss_scale)) + super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False, + initial_scale=loss_scale) + elif isinstance(loss_scale, loss_scale_module.FixedLossScale): + ls_val = loss_scale._loss_scale_value # pylint: disable=protected-access + tf_logging.warn( + warn_msg_prefix + 'For example\n' + ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(' + 'opt, dynamic=False, initial_scale={})'.format(ls_val)) + super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False, + initial_scale=ls_val) + elif loss_scale == 'dynamic': + tf_logging.warn( + warn_msg_prefix + 'For example\n' + ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(' + 'opt)') + super(LossScaleOptimizerV1, self).__init__(optimizer) + elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): + kwargs = {} + extra_arguments = '' + if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE: + kwargs['initial_scale'] = loss_scale.initial_loss_scale + extra_arguments += (', initial_scale=%s' % + loss_scale.initial_loss_scale) + if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS: + kwargs['dynamic_growth_steps'] = loss_scale.increment_period + extra_arguments += (', dynamic_growth_steps=%s' % + loss_scale.increment_period) + if loss_scale.multiplier != 2: + raise ValueError('When passing a DynamicLossScale to "loss_scale", ' + 'DynamicLossScale.multiplier must be 2. Got: %s' + % (loss_scale,)) + tf_logging.warn( + warn_msg_prefix + + 'Note that the non-experimental LossScaleOptimizer does not take a ' + 'DynamicLossScale but instead takes the dynamic configuration ' + 'directly in the constructor. For example:\n' + ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(' + 'opt{})\n'.format(extra_arguments)) + super(LossScaleOptimizerV1, self).__init__(optimizer, **kwargs) + elif isinstance(loss_scale, loss_scale_module.LossScale): + raise TypeError('Passing a LossScale that is not a FixedLossScale or a ' + 'DynamicLossScale is no longer supported. Got: {}' + .format(loss_scale)) + else: + raise ValueError('Invalid value passed to loss_scale. loss_scale ' + 'must be the string "dynamic" (recommended), an int, ' + 'a float, a FixedLossScale, or a DynamicLossScale. Got ' + 'value: {}'.format(loss_scale)) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = config.copy() # Make a copy, since we mutate config + + # If loss_scale is in config, we assume we are deserializing a + # LossScaleOptimizer from TF 2.3 or below. Otherwise, we assume we are + # deserializing a LossScaleOptimizer from TF 2.4 or above. + if 'loss_scale' in config: + config['loss_scale'] = keras_loss_scale_module.deserialize( + config['loss_scale']) + if (isinstance(config['loss_scale'], loss_scale_module.DynamicLossScale) + and config['loss_scale'].multiplier != 2): + raise ValueError('Cannot deserialize LossScaleOptimizer with a ' + 'DynamicLossScale whose multiplier is not 2. Got ' + 'DynamicLossScale: %s' % (config['loss_scale'],)) + config['optimizer'] = optimizers.deserialize( + config['optimizer'], custom_objects=custom_objects) + return cls(**config) + + # We convert the config, as generated by LossScaleOptimizer.get_config, to a + # version that can be passed to LossScaleOptimizerV1.__init__ + if config['dynamic']: + config['loss_scale'] = loss_scale_module.DynamicLossScale( + config['initial_scale'], config['dynamic_growth_steps'], multiplier=2) + else: + config['loss_scale'] = loss_scale_module.FixedLossScale( + config['initial_scale']) + + del config['dynamic'] + del config['initial_scale'] + del config['dynamic_growth_steps'] + config['optimizer'] = optimizers.deserialize( + config.pop('inner_optimizer'), custom_objects=custom_objects) + return cls(**config) class FakeOptimizerForRestoration(trackable.Trackable): @@ -654,7 +1134,7 @@ class FakeOptimizerForRestoration(trackable.Trackable): # pylint: disable=protected-access mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2, - LossScaleOptimizer) + LossScaleOptimizerV1) def _multiply_gradient(gradient, scale): diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/loss_scale_optimizer_test.py similarity index 57% rename from tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py rename to tensorflow/python/keras/mixed_precision/loss_scale_optimizer_test.py index fe3a237ef83..e9f375303a6 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py +++ b/tensorflow/python/keras/mixed_precision/loss_scale_optimizer_test.py @@ -32,15 +32,15 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras import combinations from tensorflow.python.keras import optimizers -from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer -from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util +from tensorflow.python.keras.mixed_precision import loss_scale_optimizer +from tensorflow.python.keras.mixed_precision import test_util as mp_test_util from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training.experimental import loss_scale as loss_scale_module +from tensorflow.python.training.experimental import loss_scale as tf_loss_scale_module from tensorflow.python.training.tracking import util as trackable_utils # Disable not-callable lint error, as the linter is unable to detect that @@ -93,7 +93,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): var = variables.Variable([5.0]) opt = gradient_descent.SGD(2.0) loss_scale = 10. - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=loss_scale) + self.assertEqual(self.evaluate(opt.loss_scale), loss_scale) + self.assertIsInstance(opt.loss_scale, ops.Tensor) # We need num_replicas_in_sync to divide loss_scale, otherwise loss_scale # / strategy.num_replicas_in_sync will not be exact, which could lead to # assertion failures due to rounding issues. @@ -112,7 +115,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): var = variables.Variable([2.0]) opt = gradient_descent.SGD(1.0) loss_scale = 10. - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=loss_scale) grad_check_fn = mp_test_util.create_identity_with_grad_check_fn( loss_scale) loss = grad_check_fn(var) @@ -122,9 +126,18 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # mp_test_util.create_identity_with_grad_check_fn added an assertion op. self.evaluate(run_op) + def testDynamicAttrsWithFixedLossScale(self): + opt = gradient_descent.SGD() + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=2.) + self.assertFalse(opt.dynamic) + self.assertIsNone(opt.dynamic_counter) + self.assertIsNone(opt.dynamic_growth_steps) + def testGetScaledLoss(self): opt = gradient_descent.SGD(2.0) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=2.) loss = ops.convert_to_tensor_v2_with_dispatch(5.) self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss))) self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)())) @@ -134,7 +147,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): def testGetUnscaledGradients(self): opt = gradient_descent.SGD(2.0) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=2) scaled_grads = [ ops.convert_to_tensor_v2_with_dispatch(3.), None, ops.convert_to_tensor_v2_with_dispatch(-4., dtype='float16') @@ -145,7 +159,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): def testGetUnscaledSparseGradients(self): opt = gradient_descent.SGD(2.0) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=2) sparse_scaled_grad = ops.IndexedSlices( ops.convert_to_tensor_v2_with_dispatch([[4., 2.], [8., 5.]]), ops.convert_to_tensor_v2_with_dispatch([1, 3], dtype='int32'), @@ -165,12 +180,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with strategy.scope(): var = variables.Variable([5.0]) opt = gradient_descent.SGD(learning_rate) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2, increment_period=1, multiplier=2) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) - self.assertEqual( - loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2, + dynamic_growth_steps=1) + self.assertEqual(opt.initial_scale, 2.) + self.assertIsInstance(opt.initial_scale, float) + self.assertEqual(opt.dynamic_growth_steps, 1) + self.assertIsInstance(opt.dynamic_growth_steps, int) + self.assertEqual(opt.initial_scale % strategy.num_replicas_in_sync, 0) run_fn = self._run_fn_with_grad_check(strategy, var, opt, expected_gradient) run_op = strategy.experimental_run(run_fn) @@ -189,6 +206,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # 1. self.assertAllClose([1.], self.evaluate(var)) + def testDynamicLossScaleDefaultValues(self): + opt = gradient_descent.SGD() + opt = loss_scale_optimizer.LossScaleOptimizer(opt) + self.assertEqual(opt.initial_scale, 2 ** 15) + self.assertEqual(opt.dynamic_growth_steps, 2000) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(opt.loss_scale), 2 ** 15) + # pylint: disable=cell-var-from-loop @parameterized.named_parameters(*TESTCASES) def testClipping(self, strategy_fn): @@ -198,12 +223,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with strategy.scope(), self.subTest(clip_type=clip_type): var = variables.Variable([5.0]) opt = gradient_descent.SGD(learning_rate, **{clip_type: 2.0}) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2, increment_period=1, multiplier=2) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2, + dynamic_growth_steps=1) self.assertEqual(getattr(opt, clip_type), 2.0) - self.assertEqual( - loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0) + self.assertEqual(opt.initial_scale % strategy.num_replicas_in_sync, 0) loss = lambda: var * 4 / strategy.num_replicas_in_sync run_fn = lambda: opt.minimize(loss, var_list=[var]) @@ -215,7 +238,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # The gradient is 4 but is clipped to 2, so the variable will be # init_val - clipped_grad * lr == 5 - 2 * 2 == 1 self.assertAllClose([1.], self.evaluate(var)) - self.assertEqual(self.evaluate(opt.loss_scale()), 4) + self.assertEqual(self.evaluate(opt.loss_scale), 4) # Test changing the clip amount and running again setattr(opt, clip_type, 3.0) @@ -224,7 +247,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # The gradient is 4 but is clipped to 3, so the variable will be # prev_var - clipped_grad * lr == 1 - 3 * 2 == -5 self.assertAllClose([-5.], self.evaluate(var)) - self.assertEqual(self.evaluate(opt.loss_scale()), 8) + self.assertEqual(self.evaluate(opt.loss_scale), 8) # Test Inf gradients are still skipped instead of being clipped loss = lambda: var * float('Inf') @@ -232,7 +255,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): run_op = strategy.experimental_run(run_fn) self._run_if_in_graph_mode(run_op) self.assertAllClose([-5.], self.evaluate(var)) # Var does not change - self.assertEqual(self.evaluate(opt.loss_scale()), 4) + self.assertEqual(self.evaluate(opt.loss_scale), 4) # pylint: enable=cell-var-from-loop @parameterized.named_parameters(*TESTCASES) @@ -240,9 +263,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with strategy_fn().scope() as strategy: var = variables.Variable([1.0, 2.0]) opt = gradient_descent.SGD(1.0) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2, increment_period=1, multiplier=2) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2, + dynamic_growth_steps=1) # Test optimizer with finite gradients loss = lambda: var * 2.0 / strategy.num_replicas_in_sync @@ -253,7 +275,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # Gradient is 2, so variable will have 2 subtracted from it self.assertAllClose([-1.0, 0.0], self.evaluate(var)) # Loss scale has doubled from 2 to 4 - self.assertEqual(4., self.evaluate(opt.loss_scale())) + self.assertEqual(4., self.evaluate(opt.loss_scale)) # Test optimizer with NaN gradients loss = lambda: var * float('NaN') @@ -263,7 +285,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # Variable should not change from before, due to NaN gradients. self.assertAllClose(self.evaluate(var), [-1.0, 0.0]) # Loss scale should half due to NaN gradients. - self.assertEqual(2., self.evaluate(opt.loss_scale())) + self.assertEqual(2., self.evaluate(opt.loss_scale)) @parameterized.named_parameters(*TESTCASES) def testDynamicLossScaleWithFloat16Loss(self, strategy_fn): @@ -272,9 +294,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with strategy.scope(): var = variables.Variable([5.0]) opt = gradient_descent.SGD(learning_rate) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2, increment_period=1, multiplier=2) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2, + dynamic_growth_steps=1) def loss(): return math_ops.cast(var / strategy.num_replicas_in_sync, 'float16') @@ -297,11 +318,9 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): var = variables.Variable([1.0, 2.0]) # An SGD optimizer with momentum has slot variables. opt = gradient_descent.SGD(1.0, momentum=1.) - initial_loss_scale = 2. - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=initial_loss_scale, increment_period=1, - multiplier=4) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + initial_scale = 2. + opt = loss_scale_optimizer.LossScaleOptimizer( + opt, initial_scale=initial_scale, dynamic_growth_steps=1) loss = lambda: var / strategy.num_replicas_in_sync run_fn = lambda: opt.minimize(loss, var_list=[var]) run_op = strategy.experimental_run(run_fn) @@ -312,7 +331,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # variable is subtracted by the accumulator, so the variable is subtracted # by 1. self.assertAllClose([0.0, 1.0], self.evaluate(var)) - self.assertEqual(self.evaluate(opt.loss_scale()), initial_loss_scale * 4) + self.assertEqual(self.evaluate(opt.loss_scale), initial_scale * 2) run_op = strategy.experimental_run(run_fn) self._run_if_in_graph_mode(run_op) @@ -321,14 +340,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # variable is subtracted by the accumulator, so the variable is subtracted # by 2. self.assertAllClose([-2., -1.], self.evaluate(var)) - self.assertEqual(self.evaluate(opt.loss_scale()), - initial_loss_scale * 16) + self.assertEqual(self.evaluate(opt.loss_scale), initial_scale * 4) self.assertEqual(opt.get_slot_names(), ['momentum']) def testIterations(self): opt = gradient_descent.SGD(2.0) - lso = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=10.) + lso = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=10.) lso.iterations = 7 self.assertEqual(lso.iterations, 7) self.assertEqual(opt.iterations, 7) @@ -338,7 +357,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with strategy_fn().scope() as strategy: # Test iterations is incremented in opt.minimize. opt = gradient_descent.SGD(1.0) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale='dynamic') + opt = loss_scale_optimizer.LossScaleOptimizer(opt) var = variables.Variable([5.0]) loss = lambda: var * 2.0 / strategy.num_replicas_in_sync run_fn = lambda: opt.minimize(loss, [var]) @@ -361,11 +380,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with self.test_session(): var = variables.Variable([1.0]) opt = gradient_descent.SGD(1.0) - initial_loss_scale = 2. - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=initial_loss_scale, increment_period=1, - multiplier=4) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2., + dynamic_growth_steps=1) run_op = opt.minimize(lambda: var * 2, [var]) self.evaluate(variables.global_variables_initializer()) self._run_if_in_graph_mode(run_op) @@ -377,15 +393,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): opt.set_weights([np.array(2.)]) self.assertEqual(self.evaluate(opt.variables()[0]), 2) - def testPassingNoneToLossScale(self): - opt = gradient_descent.SGD() - with self.assertRaisesRegex(ValueError, r'loss_scale cannot be None'): - loss_scale_optimizer.LossScaleOptimizer(opt, None) - def testHyperParametersExposed(self): with self.cached_session(): opt = adam.Adam(learning_rate=1.0, beta_1=0.5, beta_2=0.9) - lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + lso = loss_scale_optimizer.LossScaleOptimizer(opt) # Force hyperparameters to be created opt.lr # pylint: disable=pointless-statement self.evaluate(variables.global_variables_initializer()) @@ -420,13 +431,13 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self._set_hyper('loss_scale', 123.) opt = MyOpt() - lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + lso = loss_scale_optimizer.LossScaleOptimizer(opt) with self.assertRaises(AttributeError): - lso.loss_scale = loss_scale_module.FixedLossScale(2.) + lso.loss_scale = 2. def testArbitraryAttributesNotExposed(self): opt = gradient_descent.SGD() - lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + lso = loss_scale_optimizer.LossScaleOptimizer(opt) self.assertFalse(opt.nesterov) with self.assertRaisesRegex( AttributeError, @@ -438,15 +449,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self.assertFalse(opt.nesterov) def testDir(self): - lso = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(), - 'dynamic') + lso = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD()) dir_result = dir(lso) self.assertIn('learning_rate', dir_result) # Hyperparameter self.assertIn('lr', dir_result) # Hyperparameter self.assertIn('minimize', dir_result) # Attribute self.assertIn('loss_scale', dir_result) # Attribute self.assertNotIn('nesterov', dir_result) # Attribute on inner optimizer - self.assertIn('nesterov', dir(lso._optimizer)) + self.assertIn('nesterov', dir(lso.inner_optimizer)) def testApplyGradientsGetsUnwrappedTensors(self): # Tests that gradients passed to apply_gradients are not wrapped in a @@ -471,11 +481,125 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with create_mirrored_strategy().scope() as strategy: var = variables.Variable([5.0]) opt = MyOptimizer(learning_rate=1.0) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=1) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=1) loss = lambda: var * 2.0 run_fn = lambda: opt.minimize(loss, [var]) strategy.experimental_run(run_fn) + @parameterized.named_parameters(*TESTCASES) + def testV1Optimizer(self, strategy_fn): + strategy = strategy_fn() + learning_rate = 2. + with strategy.scope(): + # Test FixedLossScale + var = variables.Variable([5.0]) + opt = gradient_descent.SGD(learning_rate) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale=2) + self.assertIsInstance(opt.loss_scale, ops.Tensor) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(opt.loss_scale), 2) + self.assertEqual(opt.initial_scale, 2) + self.assertIsNone(opt.dynamic_growth_steps) + run_fn = self._run_fn_with_grad_check( + strategy, var, opt, 2 / strategy.num_replicas_in_sync) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + # The loss is the identity of the variable. Therefore the gradient is 1, + # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3 + self.assertAllClose([3.], self.evaluate(var)) + + # Test DynamicLossScale + var = variables.Variable([5.0]) + opt = gradient_descent.SGD(learning_rate) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, 'dynamic') + self.assertEqual(opt.initial_scale, 2 ** 15) + self.assertEqual(opt.dynamic_growth_steps, 2000) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(opt.loss_scale), 2 ** 15) + for s in strategy.experimental_local_results(opt.dynamic_counter): + self.assertEqual(self.evaluate(s), 0) + + loss = lambda: var * float('NaN') + run_fn = lambda: opt.minimize(loss, var_list=[var]) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + self.assertAllClose([5.], self.evaluate(var)) + self.assertEqual(self.evaluate(opt.loss_scale), 2 ** 14) + for s in strategy.experimental_local_results(opt.dynamic_counter): + self.assertEqual(self.evaluate(s), 0) + + @parameterized.named_parameters(*TESTCASES) + def testPassingV1LossScale(self, strategy_fn): + strategy = strategy_fn() + learning_rate = 2. + with strategy.scope(): + # Test FixedLossScale + var = variables.Variable([5.0]) + opt = gradient_descent.SGD(learning_rate) + loss_scale = tf_loss_scale_module.FixedLossScale(2.) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + self.assertIsInstance(opt.loss_scale, ops.Tensor) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(opt.loss_scale), 2) + run_fn = self._run_fn_with_grad_check( + strategy, var, opt, 2 / strategy.num_replicas_in_sync) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + # The loss is the identity of the variable. Therefore the gradient is 1, + # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3 + self.assertAllClose([3.], self.evaluate(var)) + + # Test DynamicLossScale + var = variables.Variable([5.0]) + opt = gradient_descent.SGD(learning_rate) + loss_scale = tf_loss_scale_module.DynamicLossScale( + initial_loss_scale=4, increment_period=1, multiplier=2) + loss_scale._current_loss_scale.assign(2) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + self.assertEqual(opt.initial_scale, 4) + self.assertEqual(opt.dynamic_growth_steps, 1) + self.evaluate(variables.global_variables_initializer()) + # Current loss scale is not copied so loss scale is reinitialized to 4 + self.assertEqual(self.evaluate(opt.loss_scale), 4) + for s in strategy.experimental_local_results(opt.dynamic_counter): + self.assertEqual(self.evaluate(s), 0) + + run_fn = self._run_fn_with_grad_check( + strategy, var, opt, 4 / strategy.num_replicas_in_sync) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + self.assertAllClose([3.], self.evaluate(var)) + + def testPassingV1LossScaleErrors(self): + opt = gradient_descent.SGD() + loss_scale = tf_loss_scale_module.DynamicLossScale(multiplier=4) + with self.assertRaisesRegex( + ValueError, 'When passing a DynamicLossScale to "loss_scale", ' + 'DynamicLossScale.multiplier must be 2. Got: ' + 'DynamicLossScale'): + loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + + class MyLossScale(tf_loss_scale_module.LossScale): + + def __call__(self): + return 1. + + def update(self, grads): + return None, True + + def get_config(self): + return {} + + with self.assertRaisesRegex( + TypeError, 'Passing a LossScale that is not a FixedLossScale or a ' + 'DynamicLossScale is no longer supported. Got:'): + loss_scale_optimizer.LossScaleOptimizerV1(opt, MyLossScale()) + @parameterized.named_parameters({ 'testcase_name': 'SaveAndRestoreBase', 'strategy_fn': default_strategy_fn, @@ -529,10 +653,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): var = variables.Variable([2.0]) opt = inner_opt = MySGD(1., momentum=1.) if save_with_ls: - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=1., increment_period=2., - multiplier=2.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=1., + dynamic_growth_steps=2.) run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var]) opt_op = strategy.experimental_run(run_fn) self.evaluate(variables.global_variables_initializer()) @@ -541,8 +663,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # Assert values. self.assertEqual(self.evaluate(var), 1.) if save_with_ls: - self.assertEqual(self.evaluate(loss_scale()), 1.) - self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1) + self.assertEqual(self.evaluate(opt.loss_scale), 1.) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) slot_var = opt.get_slot(var, 'momentum') self.assertEqual(self.evaluate(slot_var).item(), -1) self.assertEqual(self.evaluate(opt.iterations), 1) @@ -560,10 +682,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): var = variables.Variable([2.0]) opt = inner_opt = MySGD(1., momentum=1.) if restore_with_ls: - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=1., increment_period=2., - multiplier=2.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=1., + dynamic_growth_steps=2.) # Restore new model. checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var) @@ -578,11 +698,11 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): if context.executing_eagerly(): self.assertEqual(self.evaluate(var), 1.) if save_with_ls and restore_with_ls: - self.assertEqual(self.evaluate(loss_scale()), 1.) - self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1) + self.assertEqual(self.evaluate(opt.loss_scale), 1.) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) elif restore_with_ls: - self.assertEqual(self.evaluate(loss_scale()), 1.) - self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0) + self.assertEqual(self.evaluate(opt.loss_scale), 1.) + self.assertEqual(self.evaluate(opt.dynamic_counter), 0) self.assertEqual(self.evaluate(opt.iterations), 1) # Run the model again. @@ -611,30 +731,180 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(var), 1) self.assertEqual(self.evaluate(slot_var).item(), -1) - def testGetConfig(self): + @combinations.generate(combinations.combine( + get_config=['v1', 'v2', 'tf2_3'], from_config=['v1', 'v2'])) + def testGetConfigFixed(self, get_config, from_config): + # Get a config from LossScaleOptimizerV1, LossScaleOptimizer, or the + # LossScaleOptimizer from TF 2.3. Then restore the config into a + # LossScaleOptimizerV1 or LossScaleOptimizer opt = gradient_descent.SGD(2., momentum=0.5) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2., increment_period=3., - multiplier=4.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) - config = opt.get_config() - opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config) + if get_config == 'v1': + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, 2) + config = opt.get_config() + elif get_config == 'v2': + opt = loss_scale_optimizer.LossScaleOptimizer( + opt, dynamic=False, initial_scale=2) + config = opt.get_config() + else: + self.assertEqual(get_config, 'tf2_3') + config = { + 'optimizer': { + 'class_name': 'SGD', + 'config': { + 'learning_rate': 2.0, + 'momentum': 0.5, + 'decay': 0.0, + 'nesterov': False, + 'name': 'SGD', + } + }, + 'loss_scale': { + 'class_name': 'FixedLossScale', + 'config': {'loss_scale_value': 2.0} + }, + } + + if from_config == 'v1': + opt = loss_scale_optimizer.LossScaleOptimizerV1.from_config(config) + else: + self.assertEqual(from_config, 'v2') + opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config) + # Force hyperparameters to be created opt.lr # pylint: disable=pointless-statement self.evaluate(variables.global_variables_initializer()) + # Test attributes on the optimizer self.assertEqual(self.evaluate(opt.lr), 2.) - self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5) - self.assertEqual(self.evaluate(opt.loss_scale()), 2.) - self.assertEqual(opt.loss_scale.increment_period, 3.) - self.assertEqual(opt.loss_scale.multiplier, 4.) + self.assertEqual(self.evaluate(opt.inner_optimizer.lr), 2.) + self.assertEqual(self.evaluate(opt.momentum), 0.5) + self.assertEqual(self.evaluate(opt.loss_scale), 2.) + self.assertEqual(opt.initial_scale, 2.) + self.assertIsNone(opt.dynamic_growth_steps) + self.assertIsNone(opt.dynamic_counter) + self.assertFalse(opt.dynamic) - def testSerializationWithBuiltInOptimizer(self): + # Ensure the optimizer can be used + var = variables.Variable([5.0]) + run_op = self._run_fn_with_grad_check( + distribution_strategy_context.get_strategy(), var, opt, 2)() + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + self.assertEqual(self.evaluate(var), [3.]) + + @combinations.generate(combinations.combine( + get_config=['v1', 'v2', 'tf2_3'], from_config=['v1', 'v2'])) + def testGetConfigDynamic(self, get_config, from_config): + # Get a config from LossScaleOptimizerV1, LossScaleOptimizer, or the + # LossScaleOptimizer from TF 2.3. Then restore the config into a + # LossScaleOptimizerV1 or LossScaleOptimizer opt = gradient_descent.SGD(2., momentum=0.5) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2., increment_period=3., - multiplier=4.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + if get_config == 'v1': + loss_scale = tf_loss_scale_module.DynamicLossScale( + initial_loss_scale=2, increment_period=3) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + config = opt.get_config() + elif get_config == 'v2': + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2, + dynamic_growth_steps=3) + config = opt.get_config() + else: + self.assertEqual(get_config, 'tf2_3') + config = { + 'optimizer': { + 'class_name': 'SGD', + 'config': { + 'learning_rate': 2.0, + 'momentum': 0.5, + 'decay': 0.0, + 'nesterov': False, + 'name': 'SGD', + } + }, + 'loss_scale': { + 'class_name': 'DynamicLossScale', + 'config': { + 'initial_loss_scale': 2.0, + 'increment_period': 3, + 'multiplier': 2.0, + } + }, + } + + if from_config == 'v1': + opt = loss_scale_optimizer.LossScaleOptimizerV1.from_config(config) + else: + self.assertEqual(from_config, 'v2') + opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config) + + # Force hyperparameters to be created + opt.lr # pylint: disable=pointless-statement + self.evaluate(variables.global_variables_initializer()) + + # Test attributes on the optimizer + self.assertEqual(self.evaluate(opt.lr), 2.) + self.assertEqual(self.evaluate(opt.inner_optimizer.lr), 2.) + self.assertEqual(self.evaluate(opt.momentum), 0.5) + self.assertEqual(self.evaluate(opt.loss_scale), 2.) + self.assertEqual(opt.initial_scale, 2.) + self.assertEqual(opt.dynamic_growth_steps, 3.) + self.assertTrue(opt.dynamic) + + # Ensure the optimizer can be used + var = variables.Variable([5.0]) + run_op = self._run_fn_with_grad_check( + distribution_strategy_context.get_strategy(), var, opt, 2)() + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + self.assertEqual(self.evaluate(var), [3.]) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) + + def test_from_config_with_invalid_multiplier(self): + config = { + 'optimizer': { + 'class_name': 'SGD', + 'config': { + 'learning_rate': 2.0, + 'momentum': 0.5, + 'decay': 0.0, + 'nesterov': False, + 'name': 'SGD', + } + }, + 'loss_scale': { + 'class_name': 'DynamicLossScale', + 'config': { + 'initial_loss_scale': 2.0, + 'increment_period': 3, + 'multiplier': 4.0, + } + }, + } + + expected_error = ('Cannot deserialize LossScaleOptimizer with a ' + 'DynamicLossScale whose multiplier is not 2. Got ' + 'DynamicLossScale: DynamicLossScale\\(') + with self.assertRaisesRegex(ValueError, expected_error): + loss_scale_optimizer.LossScaleOptimizer.from_config(config) + with self.assertRaisesRegex(ValueError, expected_error): + loss_scale_optimizer.LossScaleOptimizerV1.from_config(config) + + @parameterized.named_parameters({ + 'testcase_name': 'V2', + 'use_v1': False, + }, { + 'testcase_name': 'V1', + 'use_v1': True, + },) + def testSerializationWithBuiltInOptimizer(self, use_v1): + opt = gradient_descent.SGD(2., momentum=0.5) + if use_v1: + loss_scale = tf_loss_scale_module.DynamicLossScale( + initial_loss_scale=2., increment_period=3.) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + else: + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2., + dynamic_growth_steps=3.) config = optimizers.serialize(opt) opt = optimizers.deserialize(config) # Force hyperparameters to be created @@ -642,10 +912,22 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self.evaluate(variables.global_variables_initializer()) self.assertEqual(self.evaluate(opt.lr), 2.) - self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5) - self.assertEqual(self.evaluate(opt.loss_scale()), 2.) - self.assertEqual(opt.loss_scale.increment_period, 3.) - self.assertEqual(opt.loss_scale.multiplier, 4.) + self.assertEqual(self.evaluate(opt.inner_optimizer.momentum), 0.5) + self.assertEqual(self.evaluate(opt.loss_scale), 2.) + self.assertEqual(opt.dynamic_growth_steps, 3.) + self.assertTrue(opt.dynamic, 4.) + # Deserializing a LossScaleOptimizer always always results in a V2 + # LossScaleOptimizer, even if serialized with a LossScaleOptimizerV1. + self.assertAllEqual(type(opt), loss_scale_optimizer.LossScaleOptimizer) + + # Ensure the optimizer can be used + var = variables.Variable([5.0]) + run_op = self._run_fn_with_grad_check( + distribution_strategy_context.get_strategy(), var, opt, 2)() + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + self.assertEqual(self.evaluate(var), [3.]) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) def testSerializationWithCustomOptimizer(self): class MySGD(gradient_descent.SGD): @@ -655,10 +937,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self.my_attribute = 123 opt = MySGD(2., momentum=0.5) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2., increment_period=3., - multiplier=4.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2., + dynamic_growth_steps=3.) config = optimizers.serialize(opt) custom_objects = {'MySGD': MySGD} opt = optimizers.deserialize(config, custom_objects=custom_objects) @@ -667,11 +947,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self.evaluate(variables.global_variables_initializer()) self.assertEqual(self.evaluate(opt.lr), 2.) - self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5) - self.assertEqual(self.evaluate(opt.loss_scale()), 2.) - self.assertEqual(opt.loss_scale.increment_period, 3.) - self.assertEqual(opt.loss_scale.multiplier, 4.) - self.assertEqual(opt._optimizer.my_attribute, 123) + self.assertEqual(self.evaluate(opt.inner_optimizer.momentum), 0.5) + self.assertEqual(self.evaluate(opt.loss_scale), 2.) + self.assertEqual(opt.dynamic_growth_steps, 3.) + self.assertEqual(opt.inner_optimizer.my_attribute, 123) def testUnsupportedStrategy(self): strategy = central_storage_strategy.CentralStorageStrategy() @@ -680,8 +959,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): 'CentralStorageStrategy. Try using a different Strategy, e.g. a ' 'MirroredStrategy') with strategy.scope(), self.assertRaisesRegex(ValueError, expected_error): - loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(), 1.) - opt = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(), 1.) + loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD()) + opt = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD()) with strategy.scope(): var = variables.Variable(1.0) loss = lambda: var * 2.0 @@ -689,6 +968,24 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(ValueError, expected_error): strategy.experimental_run(run_fn) + def testInvalidArgsWithFixedLossScale(self): + opt = gradient_descent.SGD() + with self.assertRaisesRegex( + ValueError, '"initial_scale" must be specified if "dynamic" is False'): + loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False) + with self.assertRaisesRegex( + ValueError, '"dynamic_growth_steps" must be None if "dynamic" is ' + 'False, but got: 2'): + loss_scale_optimizer.LossScaleOptimizer( + opt, dynamic=False, initial_scale=1, dynamic_growth_steps=2) + + def testDynamicMustBeBool(self): + opt = gradient_descent.SGD() + with self.assertRaisesRegex( + TypeError, '"dynamic" argument to LossScaleOptimizer.__init__ must be ' + "a bool, but got: 'dynamic'"): + loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py b/tensorflow/python/keras/mixed_precision/mixed_precision_graph_rewrite_test.py similarity index 91% rename from tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py rename to tensorflow/python/keras/mixed_precision/mixed_precision_graph_rewrite_test.py index d0fea573bd0..3fc9b9c455b 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py +++ b/tensorflow/python/keras/mixed_precision/mixed_precision_graph_rewrite_test.py @@ -24,8 +24,8 @@ from tensorflow.python.framework import config from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils -from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2 -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as loss_scale_optimizer_v2 +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.platform import test from tensorflow.python.training.experimental import mixed_precision @@ -65,13 +65,13 @@ class MixedPrecisionTest(keras_parameterized.TestCase): opt = gradient_descent_v2.SGD(1.0) opt = enable_mixed_precision_graph_rewrite(opt, 123.) self.assertIsInstance( - opt, loss_scale_optimizer_v2.LossScaleOptimizer) - self.assertEqual(self.evaluate(opt._loss_scale()), 123.) + opt, loss_scale_optimizer_v2.LossScaleOptimizerV1) + self.assertEqual(self.evaluate(opt.loss_scale), 123.) @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_optimizer_errors(self): opt = gradient_descent_v2.SGD(1.0) - opt = loss_scale_optimizer_v2.LossScaleOptimizer(opt, 'dynamic') + opt = loss_scale_optimizer_v2.LossScaleOptimizerV1(opt, 'dynamic') with self.assertRaisesRegex( ValueError, '"opt" must not already be an instance of a ' 'LossScaleOptimizer.'): diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/policy.py similarity index 82% rename from tensorflow/python/keras/mixed_precision/experimental/policy.py rename to tensorflow/python/keras/mixed_precision/policy.py index 33f6562f796..dec172932d2 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy.py +++ b/tensorflow/python/keras/mixed_precision/policy.py @@ -24,23 +24,23 @@ import six from tensorflow.python.framework import dtypes from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer_utils -from tensorflow.python.keras.mixed_precision.experimental import device_compatibility_check -from tensorflow.python.keras.mixed_precision.experimental import loss_scale as keras_loss_scale_module +from tensorflow.python.keras.mixed_precision import device_compatibility_check +from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module from tensorflow.python.keras.utils import generic_utils from tensorflow.python.platform import tf_logging from tensorflow.python.training.experimental import mixed_precision_global_state from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.mixed_precision.experimental.Policy', v1=[]) +@keras_export('keras.mixed_precision.Policy', v1=[]) class Policy(object): """A dtype policy for a Keras layer. - A dtype policy determines dtype-related aspects of a layer, such as its - computation and variable dtypes. Each layer has a policy. Policies can be - passed to the `dtype` argument of layer constructors, or a global policy can - be set with `tf.keras.mixed_precision.experimental.set_policy`. A layer will - default to the global policy if no policy is passed to it's constructor. + A dtype policy determines a layer's computation and variable dtypes. Each + layer has a policy. Policies can be passed to the `dtype` argument of layer + constructors, or a global policy can be set with + `tf.keras.mixed_precision.experimental.set_policy`. A layer will default to + the global policy if no policy is passed to it's constructor. For many models, each layer's policy will have the same compute dtype and variable dtype, which will typically be float32. In this case, we refer to the @@ -56,24 +56,17 @@ class Policy(object): https://www.tensorflow.org/guide/keras/mixed_precision) for more information on how to use mixed precision. - Certain policies also have a `tf.mixed_precision.experimental.LossScale` - instance, which is used by `tf.keras.Model`s to performance loss scaling. Loss - scaling is a technique used with mixed precision to avoid numerical underflow - in float16 gradients. Loss scaling is only done by Models in `Model.fit`, - `Model.train_on_batch`, and similar methods. Layers which are not Models - ignore the loss scale. - Policies are constructed by passing a string to the constructor, e.g. - `tf.keras.mixed_precision.experimental.Policy('float32')`. The string - determines the compute and variable dtypes. It can be one of the following: + `tf.keras.mixed_precision.Policy('float32')`. The string determines the + compute and variable dtypes. It can be one of the following: - * Any dtype name, such as 'float32' or 'float64'. Both the variable and - compute dtypes will be that dtype. No loss scaling is done by default. - * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or - bfloat16, while the variable dtype is float32. These policies are used for - mixed precision training. With 'mixed_float16', a dynamic loss scale is - used by default. 'mixed_bfloat16' does no loss scaling by default, as loss - scaling is unnecessary with bfloat16. + * Any dtype name, such as 'float32' or 'float64'. Both the variable and + compute dtypes will be that dtype. + * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or + bfloat16, while the variable dtype is float32. With 'mixed_float16', + `tf.keras.Model.compile` will wrap the optimizer with a + `tf.keras.mixed_precision.LossScaleOptimizer`. These policies are used for + mixed precision training. ### How to use mixed precision in a Keras model @@ -97,7 +90,7 @@ class Policy(object): Alternatively, the policy can be passed to individual layers instead of setting the global policy with `set_policy`: - >>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16') + >>> policy = tf.keras.mixed_precision.Policy('mixed_float16') >>> model = tf.keras.models.Sequential([ ... tf.keras.layers.Input((100,)), ... tf.keras.layers.Dense(10, dtype=policy), @@ -110,7 +103,7 @@ class Policy(object): `Model.fit`, `Model.train_on_batch`, and other training methods. If no such method is used (e.g., a custom training loop is used) and `'mixed_float16'` is used, the loss scale must be manually applied. See - `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details. For + `tf.keras.mixed_precision.LossScaleOptimizer` for details. For `'mixed_bfloat16'`, no loss scaling is done and loss scaling never needs to be manually applied. @@ -227,11 +220,12 @@ class Policy(object): ... def build(self, input_shape): ... self.x = self.add_weight('x') ... self.y = self.add_weight('y', experimental_autocast=False) - >>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16') + >>> policy = tf.keras.mixed_precision.Policy('mixed_float16') >>> layer = MyLayer(dtype=policy) >>> layer.build((2, 2)) >>> layer.x - + >>> layer.y @@ -257,7 +251,7 @@ class Policy(object): ... def call(self, inputs): ... return tf.matmul(inputs, self.kernel) - >>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16') + >>> policy = tf.keras.mixed_precision.Policy('mixed_float16') >>> layer = MyDense(dtype=policy) >>> x = np.random.rand(10, 10) >>> y = layer(x) @@ -288,29 +282,27 @@ class Policy(object): layer would only work if the inputs were float32. """ - def __init__(self, name, loss_scale='auto'): + def __init__(self, name): """Constructs the policy. - The `name` argument determines the compute and variable dtype, the default - loss scale, and has no additional effect on the Policy. The compute and - variable dtypes can only be specified through `name`, and cannot be + The `name` argument determines the compute and variable dtype. The compute + and variable dtypes can only be specified through `name`, and cannot be specified directly. + `name` is also used by `tf.keras.Model.compile`. If `name` is + `"mixed_float16"`, `tf.keras.Model.compile` will automatically wrap the + optimizer with a LossScaleOptimizer if it is not already a + LossScaleOptimizer. + Args: name: A string. Can be one of the following values: * Any dtype name, such as 'float32' or 'float64'. Both the variable and compute dtypes will be that dtype. * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or bfloat16, while the variable dtype is float32. With 'mixed_float16', - a dynamic loss scale is used. These policies are used for mixed - precision training. - loss_scale: A `tf.mixed_precision.experimental.LossScale`, an int (which - uses a `FixedLossScale`), the string "dynamic" (which uses a - `DynamicLossScale`), or None (which uses no loss scale). Defaults to - `"auto"`. In the `"auto"` case: 1) if `name` is `"mixed_float16"`, then - use `loss_scale="dynamic"`. 2) otherwise, do not use a loss scale. Only - `tf.keras.Model`s, not layers, use the loss scale, and it is only used - during `Model.fit`, `Model.train_on_batch`, and other similar methods. + `tf.keras.Model.compile` will wrap the optimizer with a + `tf.keras.mixed_precision.LossScaleOptimizer. These policies are used + for mixed precision training. """ if isinstance(name, dtypes.DType): raise TypeError("'name' must be a string, not a DType. " @@ -319,19 +311,6 @@ class Policy(object): raise TypeError("'name' must be a string, but got: %s" % (name,)) self._name = name self._compute_dtype, self._variable_dtype = self._parse_name(name) - - if loss_scale == 'auto': - loss_scale = 'dynamic' if name == 'mixed_float16' else None - self._using_default_loss_scale = True - else: - self._using_default_loss_scale = False - if loss_scale and self._compute_dtype not in (None, 'float16'): - tf_logging.warn('Creating a Policy with a loss scale is only useful for ' - 'float16 policies. You passed loss_scale=%r for policy ' - '%s. Consider not passing any loss_scale instead.' % - (loss_scale, name)) - self._loss_scale = keras_loss_scale_module.get(loss_scale) - if name in ('mixed_float16', 'mixed_bloat16'): device_compatibility_check.log_device_compatibility_check(name) @@ -427,32 +406,98 @@ class Policy(object): return self._compute_dtype @property - def should_cast_variables(self): - """Returns True if variables should be casted. + def name(self): + """Returns the name of this policy.""" + return self._name - This is true if the variable dtype is not the same as the compute dtype. + def __repr__(self): + return '' % self._name - Returns: - True, if variables should be casted. + def get_config(self): + return {'name': self.name} + + @classmethod + def from_config(cls, config, custom_objects=None): + del custom_objects + if 'loss_scale' in config: + config = config.copy() + # Policy.get_config in TensorFlow 2.3 and below had a loss_scale. We + # silently drop it. + del config['loss_scale'] + return cls(**config) + + +@keras_export('keras.mixed_precision.experimental.Policy', v1=[]) +class PolicyV1(Policy): + """A deprecated dtype policy for a Keras layer. + + Warning: This class is now deprecated and will be removed soon. Please use the + non-experimental class `tf.keras.mixed_precision.Policy` instead. + + The difference between this class and the non-experimental class is that this + class has a `loss_scale` field and the non-experimental class does not. The + loss scale is only used by `tf.keras.Model.compile`, which automatically wraps + the optimizer with a `LossScaleOptimizer` if the optimzier is not already a + `LossScaleOptimizer`. For the non-experimental Policy class, `Model.compile` + instead wraps the optimizer with a `LossScaleOptimizer` if `Policy.name` is + "mixed_float16". + + When deserializing objects with an experimental policy using functions like + `tf.keras.utils.deserialize_keras_object`, the policy will be deserialized as + the non-experimental `tf.keras.mixed_precision.Policy`, and the loss scale + will silently be dropped. This is so that SavedModels that are generated + with an expeirmental policy can be restored after the experimental policy is + removed. + """ + + def __init__(self, name, loss_scale='auto'): + """Constructs the policy. + + The `name` argument determines the compute and variable dtype, the default + loss scale, and has no additional effect on the Policy. The compute and + variable dtypes can only be specified through `name`, and cannot be + specified directly. + + Args: + name: A string. Can be one of the following values: + * Any dtype name, such as 'float32' or 'float64'. Both the variable and + compute dtypes will be that dtype. + * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or + bfloat16, while the variable dtype is float32. With 'mixed_float16', + a dynamic loss scale is used. These policies are used for mixed + precision training. + loss_scale: A `tf.compat.v1.mixed_precision.LossScale`, an int (which + uses a `FixedLossScale`), the string "dynamic" (which uses a + `DynamicLossScale`), or None (which uses no loss scale). Defaults to + `"auto"`. In the `"auto"` case: 1) if `name` is `"mixed_float16"`, then + use `loss_scale="dynamic"`. 2) otherwise, do not use a loss scale. Only + `tf.keras.Model`s, not layers, use the loss scale, and it is only used + during `Model.fit`, `Model.train_on_batch`, and other similar methods. """ - return self.variable_dtype != self.compute_dtype + super(PolicyV1, self).__init__(name) + if loss_scale == 'auto': + loss_scale = 'dynamic' if name == 'mixed_float16' else None + self._using_default_loss_scale = True + else: + self._using_default_loss_scale = False + if loss_scale and self._compute_dtype not in (None, 'float16'): + tf_logging.warn('Creating a Policy with a loss scale is only useful for ' + 'float16 policies. You passed loss_scale=%r for policy ' + '%s. Consider not passing any loss_scale instead.' % + (loss_scale, name)) + self._loss_scale = keras_loss_scale_module.get(loss_scale) @property def loss_scale(self): """Returns the loss scale of this Policy. Returns: - A `tf.mixed_precision.experimental.LossScale`, or None. + A `tf.compat.v1.mixed_precision.experimental.LossScale`, or None. """ return self._loss_scale - @property - def name(self): - """Returns the name of this policy.""" - return self._name - def __repr__(self): - return '' % (self._name, self.loss_scale) + return '' % (self._name, self.loss_scale) def get_config(self): config = { @@ -481,7 +526,8 @@ class Policy(object): _global_policy = None -@keras_export('keras.mixed_precision.experimental.global_policy', v1=[]) +@keras_export('keras.mixed_precision.global_policy', + 'keras.mixed_precision.experimental.global_policy', v1=[]) def global_policy(): """Returns the global Policy. @@ -496,8 +542,7 @@ def global_policy(): first time the layer is called. This behavior matches the behavior that existed in TensorFlow 1. - See `tf.keras.mixed_precision.experimental.Policy` for more information on - policies. + See `tf.keras.mixed_precision.Policy` for more information on policies. Returns: The global Policy. @@ -526,7 +571,8 @@ def _check_if_mixed_precision_graph_rewrite_is_enabled(policy): 'customizable.'.format(policy=policy)) -@keras_export('keras.mixed_precision.experimental.set_policy', v1=[]) +@keras_export('keras.mixed_precision.set_global_policy', + 'keras.mixed_precision.experimental.set_policy', v1=[]) def set_policy(policy): """Sets the global Policy. @@ -539,7 +585,7 @@ def set_policy(policy): `'int32'` and `'complex64'` cannot be set as the global policy because most layers do not support such policies. - See `tf.keras.mixed_precision.experimental.Policy` for more information. + See `tf.keras.mixed_precision.Policy` for more information. Args: policy: A Policy, or a string that will be converted to a Policy.. @@ -552,7 +598,8 @@ def set_policy(policy): '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"') if policy is not None and not isinstance(policy, Policy): policy = Policy(policy) - is_mixed_policy = policy is not None and policy.should_cast_variables + is_mixed_policy = (policy is not None and + policy.compute_dtype != policy.variable_dtype) if is_mixed_policy: _check_if_mixed_precision_graph_rewrite_is_enabled(policy) if (policy is not None and policy.compute_dtype is not None and @@ -596,8 +643,8 @@ def _policy_equivalent_to_dtype(policy): """Returns True if the Policy is equivalent to a single dtype. A policy is equivalent to a single dtype if the policy's compute and variable - dtypes are the same and the policy does not cause the layer/model to have - additional behavior, such as loss scaling. + dtypes are the same and the policy's type is Policy and not a subclass of + Policy (such as PolicyV1). The "_infer" policy is considered equivalent to a single dtype. @@ -628,7 +675,7 @@ def deserialize(config, custom_objects=None): return Policy(config) if config is None: return Policy('_infer') - module_objects = {'Policy': Policy} + module_objects = {'Policy': Policy, 'PolicyV1': Policy} return generic_utils.deserialize_keras_object( config, module_objects=module_objects, diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/policy_test.py similarity index 86% rename from tensorflow/python/keras/mixed_precision/experimental/policy_test.py rename to tensorflow/python/keras/mixed_precision/policy_test.py index 9ebcc3558e6..85c41a2adeb 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py +++ b/tensorflow/python/keras/mixed_precision/policy_test.py @@ -27,8 +27,8 @@ from tensorflow.python.framework import ops from tensorflow.python.keras import combinations from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import base_layer_utils -from tensorflow.python.keras.mixed_precision.experimental import device_compatibility_check -from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy +from tensorflow.python.keras.mixed_precision import device_compatibility_check +from tensorflow.python.keras.mixed_precision import policy as mp_policy from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -60,14 +60,21 @@ class PolicyTest(test.TestCase, parameterized.TestCase): @testing_utils.enable_v2_dtype_behavior def test_repr(self): - for policy in ('float32', 'int8', 'mixed_bfloat16', '_infer'): + # Test Policy repr + for policy in ('float32', 'int8', 'mixed_float16', 'mixed_bfloat16', + '_infer'): self.assertEqual(repr(mp_policy.Policy(policy)), - '' % policy) - self.assertEqual(repr(mp_policy.Policy('float16', loss_scale=2)), - '') + '' % policy) + + # Test PolicyV1 repr + for policy in ('float32', 'int8', 'mixed_bfloat16', '_infer'): + self.assertEqual(repr(mp_policy.PolicyV1(policy)), + '' % policy) + self.assertEqual(repr(mp_policy.PolicyV1('float16', loss_scale=2)), + '') self.assertStartsWith( - repr(mp_policy.Policy('mixed_float16')), - '?@[\\]^_`{|}~\\t\\n``, - includes basic punctuation, tabs, and newlines. + punctuation. Default: + ``` + '!"#$%&()*+,-./:;<=>?@[\]^_`{|}~\t\n + ```, + includes basic punctuation, tabs, and newlines. lower: boolean. Whether to set the text to lowercase. split: str. Separator for word splitting. diff --git a/tensorflow/python/keras/protobuf/BUILD b/tensorflow/python/keras/protobuf/BUILD index df253252f36..b7d85419fb9 100644 --- a/tensorflow/python/keras/protobuf/BUILD +++ b/tensorflow/python/keras/protobuf/BUILD @@ -14,3 +14,10 @@ tf_proto_library( srcs = ["projector_config.proto"], cc_api_version = 2, ) + +tf_proto_library( + name = "saved_metadata_proto", + srcs = ["saved_metadata.proto"], + cc_api_version = 2, + protodeps = ["//tensorflow/core:protos_all"], +) diff --git a/tensorflow/python/keras/protobuf/saved_metadata.proto b/tensorflow/python/keras/protobuf/saved_metadata.proto new file mode 100644 index 00000000000..41684bbd627 --- /dev/null +++ b/tensorflow/python/keras/protobuf/saved_metadata.proto @@ -0,0 +1,33 @@ +// Protobuf containing the metadata for each Keras object saved in a SavedModel. + +syntax = "proto3"; + +package third_party.tensorflow.python.keras.protobuf; + +import "tensorflow/core/framework/versions.proto"; + +message SavedMetadata { + // Nodes represent trackable objects in the SavedModel. The data for every + // Keras object is stored. + repeated SavedObject nodes = 1; +} + +// Metadata of an individual Keras object. +message SavedObject { + // Version defined by the code serializing this Keras object. + .tensorflow.VersionDef version = 1; + // Index of the node in the SavedModel SavedObjectGraph. + int32 node_id = 2; + // String path from root (e.g. "root.child_layer") + string node_path = 3; + + // Identifier to determine loading function. + // Currently supported identifiers: + // _tf_keras_layer, _tf_keras_input_layer, _tf_keras_rnn_layer, + // _tf_keras_metric, _tf_keras_network, _tf_keras_model, + // _tf_keras_sequential + string identifier = 4; + // Metadata containing a JSON-serialized object with the non-TensorFlow + // attributes for this Keras object. + string metadata = 5; +} diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD index 4b7d0e2fef5..2e462662e91 100644 --- a/tensorflow/python/keras/saving/BUILD +++ b/tensorflow/python/keras/saving/BUILD @@ -57,7 +57,8 @@ py_library( "//tensorflow/python/keras:optimizers", "//tensorflow/python/keras:regularizers", "//tensorflow/python/keras/engine:input_spec", - "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable", + "//tensorflow/python/keras/mixed_precision:autocast_variable", + "//tensorflow/python/keras/protobuf:saved_metadata_proto_py", "//tensorflow/python/keras/utils:engine_utils", "//tensorflow/python/keras/utils:metrics_utils", "//tensorflow/python/keras/utils:mode_keys", @@ -73,6 +74,9 @@ tf_py_test( srcs = ["metrics_serialization_test.py"], python_version = "PY3", shard_count = 8, + tags = [ + "notsan", # TODO(b/170870790) + ], deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -206,7 +210,6 @@ tf_py_test( size = "small", srcs = ["saved_model/json_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":saving", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/saving/hdf5_format_test.py b/tensorflow/python/keras/saving/hdf5_format_test.py index 9ed90bc4999..21a167683ee 100644 --- a/tensorflow/python/keras/saving/hdf5_format_test.py +++ b/tensorflow/python/keras/saving/hdf5_format_test.py @@ -58,7 +58,7 @@ except ImportError: @combinations.generate(combinations.combine(mode=['graph', 'eager'])) class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): - @keras_parameterized.run_with_all_saved_model_formats + @keras_parameterized.run_with_all_weight_formats def test_weight_loading(self): temp_dir = self.get_temp_dir() self.addCleanup(shutil.rmtree, temp_dir) @@ -410,9 +410,14 @@ class TestWholeModelSaving(keras_parameterized.TestCase): def test_save_and_load(self): saved_model_dir = self._save_model_dir() save_format = testing_utils.get_save_format() + save_kwargs = testing_utils.get_save_kwargs() - if save_format == 'h5' and testing_utils.get_model_type() == 'subclass': - return # HDF5 format currently does not allow saving classed models. + if ((save_format == 'h5' or not save_kwargs.get('save_traces', True)) and + testing_utils.get_model_type() == 'subclass'): + # HDF5 format currently does not allow saving subclassed models. + # When saving with `save_traces=False`, the subclassed model must have a + # get_config/from_config, which the autogenerated model does not have. + return with self.cached_session(): model = testing_utils.get_model_from_layers( @@ -440,7 +445,9 @@ class TestWholeModelSaving(keras_parameterized.TestCase): model.train_on_batch(x, y) out = model.predict(x) - keras.models.save_model(model, saved_model_dir, save_format=save_format) + keras.models.save_model( + model, saved_model_dir, save_format=save_format, + **save_kwargs) loaded_model = keras.models.load_model(saved_model_dir) self._assert_same_weights_and_metrics(model, loaded_model) diff --git a/tensorflow/python/keras/saving/model_config.py b/tensorflow/python/keras/saving/model_config.py index 63f82b404a4..facc95b22f9 100644 --- a/tensorflow/python/keras/saving/model_config.py +++ b/tensorflow/python/keras/saving/model_config.py @@ -34,6 +34,15 @@ except ImportError: @keras_export('keras.models.model_from_config') def model_from_config(config, custom_objects=None): """Instantiates a Keras model from its config. + + Usage: + ``` + # for a Functional API model + tf.keras.Model().from_config(model.get_config()) + + # for a Sequential model + tf.keras.Sequential().from_config(model.get_config()) + ``` Arguments: config: Configuration dictionary. diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py index c0c69c4e715..9f03197920f 100644 --- a/tensorflow/python/keras/saving/save.py +++ b/tensorflow/python/keras/saving/save.py @@ -52,9 +52,14 @@ def save_model(model, include_optimizer=True, save_format=None, signatures=None, - options=None): + options=None, + save_traces=True): + # pylint: disable=line-too-long """Saves a model as a TensorFlow SavedModel or HDF5 file. + See the [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/) + for details. + Usage: >>> model = tf.keras.Sequential([ @@ -65,28 +70,38 @@ def save_model(model, >>> x = tf.random.uniform((10, 3)) >>> assert np.allclose(model.predict(x), loaded_model.predict(x)) - The saved model contains: + The SavedModel and HDF5 file contains: - - the model's configuration (topology) - - the model's weights - - the model's optimizer's state (if any) + - the model's configuration (topology) + - the model's weights + - the model's optimizer's state (if any) - Thus the saved model can be reinstantiated in - the exact same state, without any of the code - used for model definition or training. + Thus models can be reinstantiated in the exact same state, without any of the + code used for model definition or training. Note that the model weights may have different scoped names after being loaded. Scoped names include the model/layer names, such as `"dense_1/kernel:0"`. It is recommended that you use the layer properties to access specific variables, e.g. `model.get_layer("dense_1").kernel`. - _SavedModel serialization_ + __SavedModel serialization format__ - The SavedModel serialization path uses `tf.saved_model.save` to save the model - and all trackable objects attached to the model (e.g. layers and variables). - `@tf.function`-decorated methods are also saved. Additional trackable objects - and functions are added to the SavedModel to allow the model to be - loaded back as a Keras Model object. + Keras SavedModel uses `tf.saved_model.save` to save the model and all + trackable objects attached to the model (e.g. layers and variables). The model + config, weights, and optimizer are saved in the SavedModel. Additionally, for + every Keras layer attached to the model, the SavedModel stores: + + * the config and metadata -- e.g. name, dtype, trainable status + * traced call and loss functions, which are stored as TensorFlow subgraphs. + + The traced functions allow the SavedModel format to save and load custom + layers without the original class definition. + + You can choose to not save the traced functions by disabling the `save_traces` + option. This will decrease the time it takes to save the model and the + amount of disk space occupied by the output SavedModel. If you enable this + option, then you _must_ provide all custom class definitions when loading + the model. See the `custom_objects` argument in `tf.keras.models.load_model`. Arguments: model: Keras model instance to be saved. @@ -102,12 +117,19 @@ def save_model(model, signatures: Signatures to save with the SavedModel. Applicable to the 'tf' format only. Please see the `signatures` argument in `tf.saved_model.save` for details. - options: Optional `tf.saved_model.SaveOptions` object that specifies - options for saving to SavedModel. + options: (only applies to SavedModel format) `tf.saved_model.SaveOptions` + object that specifies options for saving to SavedModel. + save_traces: (only applies to SavedModel format) When enabled, the + SavedModel will store the function traces for each layer. This + can be disabled, so that only the configs of each layer are stored. + Defaults to `True`. Disabling this will decrease serialization time and + reduce file size, but it requires that all custom layers/models + implement a `get_config()` method. Raises: ImportError: If save format is hdf5, and h5py is not available. """ + # pylint: enable=line-too-long from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top default_format = 'tf' if tf2.enabled() else 'h5' @@ -132,7 +154,7 @@ def save_model(model, model, filepath, overwrite, include_optimizer) else: saved_model_save.save(model, filepath, overwrite, include_optimizer, - signatures, options) + signatures, options, save_traces) @keras_export('keras.models.load_model') diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py index 7330a4a189f..fcd2003aab8 100644 --- a/tensorflow/python/keras/saving/save_test.py +++ b/tensorflow/python/keras/saving/save_test.py @@ -290,7 +290,7 @@ class TestSaveModel(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_saving_model_with_custom_object(self): - with generic_utils.custom_object_scope(): + with generic_utils.custom_object_scope(), self.cached_session(): @generic_utils.register_keras_serializable() class CustomLoss(losses.MeanSquaredError): diff --git a/tensorflow/python/keras/saving/saved_model/base_serialization.py b/tensorflow/python/keras/saving/saved_model/base_serialization.py index 0065e6d786e..0b26fe774b1 100644 --- a/tensorflow/python/keras/saving/saved_model/base_serialization.py +++ b/tensorflow/python/keras/saving/saved_model/base_serialization.py @@ -22,6 +22,7 @@ import abc import six from tensorflow.python.keras.saving.saved_model import json_utils +from tensorflow.python.keras.saving.saved_model import utils from tensorflow.python.training.tracking import tracking @@ -71,6 +72,9 @@ class SavedModelSaver(object): A dictionary mapping attribute names to trackable objects. The entire list of attributes are listed in the `saved_model._LayerAttributes` class. """ + if not utils.should_save_traces(): + return {} + return self.objects_to_serialize(serialization_cache) def list_functions_for_serialization(self, serialization_cache): @@ -84,6 +88,9 @@ class SavedModelSaver(object): A dictionary mapping attribute names to `Function` or `ConcreteFunction`. """ + if not utils.should_save_traces(): + return {} + fns = self.functions_to_serialize(serialization_cache) # The parent AutoTrackable class saves all user-defined tf.functions, and diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py index 4216457bf28..73f3ba250a4 100644 --- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py +++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.saving.saved_model import base_serialization from tensorflow.python.keras.saving.saved_model import constants from tensorflow.python.keras.saving.saved_model import save_impl diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index e0426e74f6b..cb6d340ea03 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -20,6 +20,7 @@ from __future__ import print_function import re import types +from tensorflow.core.framework import versions_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import function as defun from tensorflow.python.framework import ops @@ -28,6 +29,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend from tensorflow.python.keras import regularizers from tensorflow.python.keras.engine import input_spec +from tensorflow.python.keras.protobuf import saved_metadata_pb2 from tensorflow.python.keras.saving import saving_utils from tensorflow.python.keras.saving.saved_model import constants from tensorflow.python.keras.saving.saved_model import json_utils @@ -38,9 +40,11 @@ from tensorflow.python.keras.utils import metrics_utils from tensorflow.python.keras.utils.generic_utils import LazyLoader from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import load as tf_load +from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.saved_model import revived_types from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.training.tracking import data_structures from tensorflow.python.training.tracking.tracking import delete_tracking from tensorflow.python.util import compat from tensorflow.python.util import nest @@ -117,8 +121,33 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics. # TODO(kathywu): Add code to load from objects that contain all endpoints - model = tf_load.load_internal( - path, options=options, loader_cls=KerasObjectLoader) + # The Keras metadata file is not yet saved, so create it from the SavedModel. + metadata = saved_metadata_pb2.SavedMetadata() + meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0] + object_graph_def = meta_graph_def.object_graph_def + # TODO(kathywu): When the keras metadata file is saved, load it directly + # instead of calling the _read_legacy_metadata function. + _read_legacy_metadata(object_graph_def, metadata) + + if not metadata.nodes: + # When there are no Keras objects, return the results from the core loader + return tf_load.load(path, options=options) + + # Recreate layers and metrics using the info stored in the metadata. + keras_loader = KerasObjectLoader(metadata, object_graph_def) + keras_loader.load_layers() + + # Generate a dictionary of all loaded nodes. + nodes_to_load = {'root': None} + for node_id, loaded_node in keras_loader.loaded_nodes.items(): + nodes_to_load[keras_loader.get_path(node_id)] = loaded_node + loaded = tf_load.load_partial(path, nodes_to_load, options=options) + + # Finalize the loaded layers and remove the extra tracked dependencies. + keras_loader.finalize_objects() + keras_loader.del_tracking() + + model = loaded['root'] # pylint: disable=protected-access if isinstance(model, training_lib.Model) and compile: @@ -143,6 +172,45 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin return model +def _read_legacy_metadata(object_graph_def, metadata): + """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.""" + # Older SavedModels store the metadata directly in the proto instead of the + # separate pb file. + node_paths = _generate_object_paths(object_graph_def) + for node_id, proto in enumerate(object_graph_def.nodes): + if (proto.WhichOneof('kind') == 'user_object' and + proto.user_object.identifier in KERAS_OBJECT_IDENTIFIERS): + metadata.nodes.add( + node_id=node_id, + node_path=node_paths[node_id], + version=versions_pb2.VersionDef( + producer=1, min_consumer=1, bad_consumers=[]), + identifier=proto.user_object.identifier, + metadata=proto.user_object.metadata) + + +def _generate_object_paths(object_graph_def): + """Traverses through an ObjectGraphDef and builds a map of all node paths.""" + paths = {0: 'root'} + nodes_to_visit = [0] + visited_nodes = set([]) + + while nodes_to_visit: + current_node = nodes_to_visit.pop() + # if current_node in visited_nodes: + # continue + visited_nodes.add(current_node) + current_path = paths[current_node] + for reference in object_graph_def.nodes[current_node].children: + if reference.node_id in paths: + continue + paths[reference.node_id] = '{}.{}'.format(current_path, + reference.local_name) + nodes_to_visit.append(reference.node_id) + + return paths + + def _is_graph_network(layer): """Determines whether the layer is a graph network.""" # pylint: disable=protected-access @@ -154,7 +222,7 @@ def _is_graph_network(layer): return False -class KerasObjectLoader(tf_load.Loader): +class KerasObjectLoader(object): """Loader that recreates Keras objects (e.g. layers, models). Layers and models are revived from either the config or SavedModel following @@ -173,15 +241,17 @@ class KerasObjectLoader(tf_load.Loader): """ - def __init__(self, *args, **kwargs): - # Maps node id -> (node, revive setter function) - # Nodes recreated from the config may generate other nodes. This list - # records all nodes that were generated directly/indirectly from the config, - # so that they do not get recreated multiple times. - self._nodes_recreated_from_config = {} + def __init__(self, metadata, object_graph_def): + self._metadata = metadata + self._proto = object_graph_def + + self._node_paths = {node_data.node_id: node_data.node_path + for node_data in metadata.nodes} + self.loaded_nodes = {} # Maps node path -> loaded node + # Store all node ids that have already been traversed when tracking nodes # that were recreated from the config. - self._traversed_nodes_from_config = [] + self._traversed_nodes_from_config = set() # Maps model id -> (blank model obj, list of child layer or their node ids) # This tracks all layers in functional and sequential models. These models @@ -189,8 +259,8 @@ class KerasObjectLoader(tf_load.Loader): self.model_layer_dependencies = {} self._models_to_reconstruct = [] - super(KerasObjectLoader, self).__init__(*args, **kwargs) - + def del_tracking(self): + """Removes tracked references that are only used when loading the model.""" # Now that the node object has been fully loaded, and the checkpoint has # been restored, the object no longer needs to track objects added from # SerializedAttributes. (Note that saving a training checkpoint still @@ -199,50 +269,33 @@ class KerasObjectLoader(tf_load.Loader): # TODO(kathywu): Instead of outright deleting these nodes (which would # make restoring from a different checkpoint tricky), mark them as extra # dependencies that are OK to overwrite. - for node in self._nodes: + for node in self.loaded_nodes.values(): + node = node[0] if not isinstance(node, base_layer.Layer): + # Loaded nodes can contain other trackable objects created when + # loading layers from the config, such as variables. continue for name in PUBLIC_ATTRIBUTES: delete_tracking(node, name) - def _load_all(self): - """Reconstruct the object graph from the SavedModel.""" - # Load layer and model objects from either config or SavedModel. The objects - # loaded from config may create variables / other objects during - # initialization. These are recorded in `_nodes_recreated_from_config`. - self._layer_nodes = self._load_layers() - - # Load all other nodes and functions. - super(KerasObjectLoader, self)._load_all() - - # Finish setting up layers and models. See function docstring for more info. - self._finalize_objects() - - @property - def _expect_partial_checkpoint(self): - return True - - def _recreate(self, proto, node_id): - """Creates a Python object from a SavedObject protocol buffer.""" - if node_id in self._layer_nodes: - return self._layer_nodes[node_id] - - if node_id in self._nodes_recreated_from_config: - obj, setter = self._nodes_recreated_from_config[node_id] - - # Overwrite variable names with the ones saved in the SavedModel. - if proto.WhichOneof('kind') == 'variable' and proto.variable.name: - obj._handle_name = proto.variable.name + ':0' # pylint: disable=protected-access - else: - obj, setter = super(KerasObjectLoader, self)._recreate(proto, node_id) - return obj, setter + if isinstance(node, functional_lib.Functional): + # Delete the temporary layer dependencies, which were used to restore + # the checkpointed values. When the model is live, the user can delete + # or add layers to the model at any time, so these layer dependencies + # may be obsolete. + dependencies = list(node._self_unconditional_dependency_names) # pylint: disable=protected-access + for name in dependencies: + if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None: + delete_tracking(node, name) def _add_children_recreated_from_config(self, obj, proto, node_id): """Recursively records objects recreated from config.""" # pylint: disable=protected-access if node_id in self._traversed_nodes_from_config: return - self._traversed_nodes_from_config.append(node_id) + + parent_path = self._node_paths[node_id] + self._traversed_nodes_from_config.add(node_id) obj._maybe_initialize_trackable() if isinstance(obj, base_layer.Layer) and not obj.built: metadata = json_utils.decode(proto.user_object.metadata) @@ -253,21 +306,23 @@ class KerasObjectLoader(tf_load.Loader): # Look for direct children for reference in proto.children: obj_child = obj._lookup_dependency(reference.local_name) - children.append((obj_child, reference.node_id)) + children.append((obj_child, reference.node_id, reference.local_name)) # Add metrics that may have been added to the layer._metrics list. # This is stored in the SavedModel as layer.keras_api.layer_metrics in # SavedModels created after Tf 2.2. metric_list_node_id = self._search_for_child_node( - node_id, [constants.KERAS_ATTR, 'layer_metrics'], raise_error=False) + node_id, [constants.KERAS_ATTR, 'layer_metrics']) if metric_list_node_id is not None and hasattr(obj, '_metrics'): obj_metrics = {m.name: m for m in obj._metrics} for reference in self._proto.nodes[metric_list_node_id].children: metric = obj_metrics.get(reference.local_name) if metric is not None: - children.append((metric, reference.node_id)) + metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR, + reference.local_name) + children.append((metric, reference.node_id, metric_path)) - for (obj_child, child_id) in children: + for (obj_child, child_id, child_name) in children: child_proto = self._proto.nodes[child_id] if not isinstance(obj_child, trackable.Trackable): @@ -281,47 +336,57 @@ class KerasObjectLoader(tf_load.Loader): setter = setattr # pylint: enable=protected-access - if (child_id in self._nodes_recreated_from_config and - self._nodes_recreated_from_config[child_id][0] is not obj_child): - # This means that the same trackable object is referenced by two - # different objects that were recreated from the config. - logging.warn('Looks like there is an object (perhaps variable or layer)' - ' that is shared between different layers/models. This ' - 'may cause issues when restoring the variable values.' - 'Object: {}'.format(obj_child)) - self._nodes_recreated_from_config[child_id] = ( - obj_child, self._config_node_setter(setter)) + if child_id in self.loaded_nodes: + if self.loaded_nodes[child_id][0] is not obj_child: + # This means that the same trackable object is referenced by two + # different objects that were recreated from the config. + logging.warn('Looks like there is an object (perhaps variable or ' + 'layer) that is shared between different layers/models. ' + 'This may cause issues when restoring the variable ' + 'values. Object: {}'.format(obj_child)) + continue + + # Overwrite variable names with the ones saved in the SavedModel. + if (child_proto.WhichOneof('kind') == 'variable' and + child_proto.variable.name): + obj_child._handle_name = child_proto.variable.name + ':0' # pylint: disable=protected-access + + if isinstance(obj_child, data_structures.TrackableDataStructure): + setter = lambda *args: None + + child_path = '{}.{}'.format(parent_path, child_name) + self._node_paths[child_id] = child_path self._add_children_recreated_from_config( obj_child, child_proto, child_id) + self.loaded_nodes[child_id] = obj_child, setter - def _load_layers(self): - layers = {} - + def load_layers(self): + """Load all layer nodes from the metadata.""" # Load metrics after models and layers, since it's likely that models # and layers will create the metric when initialized (this avoids wasting # time by creating objects multiple times). metric_list = [] - for node_id, proto in enumerate(self._proto.nodes): - if (proto.WhichOneof('kind') != 'user_object' or - proto.user_object.identifier not in KERAS_OBJECT_IDENTIFIERS): - continue - if proto.user_object.identifier == '_tf_keras_metric': - metric_list.append((node_id, proto)) + for node_metadata in self._metadata.nodes: + if node_metadata.identifier == '_tf_keras_metric': + metric_list.append(node_metadata) continue - layers[node_id] = self._load_layer(proto.user_object, node_id) + self.loaded_nodes[node_metadata.node_id] = self._load_layer( + node_metadata.node_id, node_metadata.identifier, + node_metadata.metadata) - for node_id, proto in metric_list: - layers[node_id] = self._load_layer(proto.user_object, node_id) - return layers + for node_metadata in metric_list: + self.loaded_nodes[node_metadata.node_id] = self._load_layer( + node_metadata.node_id, node_metadata.identifier, + node_metadata.metadata) - def _load_layer(self, proto, node_id): + def _load_layer(self, node_id, identifier, metadata): """Load a single layer from a SavedUserObject proto.""" - metadata = json_utils.decode(proto.metadata) + metadata = json_utils.decode(metadata) # If node was already created - if node_id in self._nodes_recreated_from_config: - node, setter = self._nodes_recreated_from_config[node_id] + if node_id in self.loaded_nodes: + node, setter = self.loaded_nodes[node_id] # Revive setter requires the object to have a `_serialized_attributes` # property. Add it here. @@ -329,15 +394,17 @@ class KerasObjectLoader(tf_load.Loader): config = metadata.get('config') if _is_graph_network(node) and generic_utils.validate_config(config): - self.model_layer_dependencies[node_id] = ( - node, self._get_child_layer_node_ids(node_id, node.name)) + child_nodes = self._get_child_layer_node_ids(node_id) + self.model_layer_dependencies[node_id] = (node, child_nodes) + if not child_nodes: + self._models_to_reconstruct.append(node_id) return node, setter # Detect whether this object can be revived from the config. If not, then # revive from the SavedModel instead. - obj, setter = self._revive_from_config(proto.identifier, metadata, node_id) + obj, setter = self._revive_from_config(identifier, metadata, node_id) if obj is None: - obj, setter = revive_custom_object(proto.identifier, metadata) + obj, setter = revive_custom_object(identifier, metadata) # Add an attribute that stores the extra functions/objects saved in the # SavedModel. Most of these functions/objects are ignored, but some are @@ -349,7 +416,7 @@ class KerasObjectLoader(tf_load.Loader): def _revive_from_config(self, identifier, metadata, node_id): """Revives a layer/model from config, or returns None.""" if identifier == '_tf_keras_metric': - obj = self._revive_metric_from_config(metadata, node_id) + obj = self._revive_metric_from_config(metadata) else: obj = ( self._revive_graph_network(metadata, node_id) or @@ -359,7 +426,6 @@ class KerasObjectLoader(tf_load.Loader): return None, None setter = self._config_node_setter(_revive_setter) - self._nodes_recreated_from_config[node_id] = obj, setter self._add_children_recreated_from_config( obj, self._proto.nodes[node_id], node_id) return obj, setter @@ -394,9 +460,10 @@ class KerasObjectLoader(tf_load.Loader): # Record this model and its layers. This will later be used to reconstruct # the model. - layers = self._get_child_layer_node_ids(node_id, model.name) + layers = self._get_child_layer_node_ids(node_id) self.model_layer_dependencies[node_id] = (model, layers) - + if not layers: + self._models_to_reconstruct.append(node_id) return model def _revive_layer_from_config(self, metadata, node_id): @@ -449,7 +516,8 @@ class KerasObjectLoader(tf_load.Loader): return obj - def _revive_metric_from_config(self, metadata, node_id): + def _revive_metric_from_config(self, metadata): + """Revives a metric object using the config saved in the metadata.""" class_name = compat.as_str(metadata['class_name']) config = metadata.get('config') @@ -490,7 +558,10 @@ class KerasObjectLoader(tf_load.Loader): if node_id not in self.model_layer_dependencies: self._add_object_graph_edges(proto, node_id) - def _finalize_objects(self): + def get_path(self, node_id): + return self._node_paths[node_id] + + def finalize_objects(self): """Finish setting up Keras objects. This function is executed after all objects and functions have been created. @@ -504,7 +575,7 @@ class KerasObjectLoader(tf_load.Loader): # functions and losses to each object, and sets model inputs/outputs. layers_revived_from_config = [] layers_revived_from_saved_model = [] - for node_id, node in enumerate(self._nodes): + for node_id, (node, _) in self.loaded_nodes.items(): if (not isinstance(node, base_layer.Layer) or # Don't finalize models until all layers have finished loading. node_id in self.model_layer_dependencies): @@ -517,10 +588,10 @@ class KerasObjectLoader(tf_load.Loader): elif isinstance(node, metrics.Metric): continue - if node_id in self._nodes_recreated_from_config: - layers_revived_from_config.append(node) - else: + if isinstance(node, (RevivedLayer, RevivedInputLayer)): layers_revived_from_saved_model.append(node) + else: + layers_revived_from_config.append(node) _finalize_saved_model_layers(layers_revived_from_saved_model) _finalize_config_layers(layers_revived_from_config) @@ -539,13 +610,13 @@ class KerasObjectLoader(tf_load.Loader): self._models_to_reconstruct.append(model_id) def _reconstruct_all_models(self): + """Reconstructs the network structure of all models.""" all_initialized_models = set() while self._models_to_reconstruct: model_id = self._models_to_reconstruct.pop(0) all_initialized_models.add(model_id) model, layers = self.model_layer_dependencies[model_id] self._reconstruct_model(model_id, model, layers) - self._add_object_graph_edges(self._proto.nodes[model_id], model_id) _finalize_config_layers([model]) if all_initialized_models != set(self.model_layer_dependencies.keys()): @@ -560,10 +631,17 @@ class KerasObjectLoader(tf_load.Loader): .format(uninitialized_model_names)) def _reconstruct_model(self, model_id, model, layers): + """Reconstructs the network structure.""" config = json_utils.decode( self._proto.nodes[model_id].user_object.metadata)['config'] - if isinstance(model, models_lib.Sequential): - if not isinstance(layers[0], input_layer.InputLayer): + + # Set up model inputs + if model.inputs: + # Inputs may already be created if the model is instantiated in another + # object's __init__. + pass + elif isinstance(model, models_lib.Sequential): + if not layers or not isinstance(layers[0], input_layer.InputLayer): if config['layers'][0]['class_name'] == 'InputLayer': layers.insert(0, input_layer.InputLayer.from_config( config['layers'][0]['config'])) @@ -576,13 +654,13 @@ class KerasObjectLoader(tf_load.Loader): name=layers[0].name + '_input')) model.__init__(layers, name=config['name']) if not model.inputs: - first_layer = self._get_child_layer_node_ids(model_id, model.name)[0] + first_layer = self._get_child_layer_node_ids(model_id)[0] input_specs = self._infer_inputs(first_layer) input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True) model._set_inputs(input_specs) # pylint: disable=protected-access if not model.built and not isinstance(input_specs, dict): model.build(input_shapes) - else: + else: # Reconstruct functional model (inputs, outputs, created_layers) = functional_lib.reconstruct_from_config( config, created_layers={layer.name: layer for layer in layers}) @@ -595,15 +673,31 @@ class KerasObjectLoader(tf_load.Loader): # Unblock models that are dependent on this model. self._unblock_model_reconstruction(model_id, model) - def _get_child_layer_node_ids(self, node_id, name): - """Returns the node ids of the children layers of a node.""" - # Retrieve the node id of layer.keras_api.layers. - layer_list = self._search_for_child_node( - node_id, [constants.KERAS_ATTR, 'layers'], name) - return [node.node_id for node in self._proto.nodes[layer_list].children] + def _get_child_layer_node_ids(self, node_id): + """Returns the node ids of each layer in a Sequential/Functional model.""" + # Sequential and Functional track layers with names following the format + # "layer-N". Use this to generate the list of layers. + num_layers = 0 + child_layers = {} + pattern = re.compile('layer-(\\d+)') - def _search_for_child_node( - self, parent_id, path_to_child, debugging_name=None, raise_error=True): + for child in self._proto.nodes[node_id].children: + m = pattern.match(child.local_name) + if m is None: + continue + layer_n = int(m.group(1)) + num_layers = max(layer_n + 1, num_layers) + child_layers[layer_n] = child.node_id + + ordered = [] + for n in range(num_layers): + child = child_layers.get(n) + if child is None: + break + ordered.append(child) + return ordered + + def _search_for_child_node(self, parent_id, path_to_child): """Returns node id of child node. A helper method for traversing the object graph proto. @@ -621,37 +715,23 @@ class KerasObjectLoader(tf_load.Loader): Args: parent_id: node id of parent node path_to_child: list of children names. - debugging_name: the name to print out when raising an error. - raise_error: Whether to raise an error if the child isn't found. Returns: node_id of child, or None if child isn't found. - - Raises: - ValueError: if child isn't found and raise_error is True. """ if not path_to_child: return parent_id for child in self._proto.nodes[parent_id].children: if child.local_name == path_to_child[0]: - return self._search_for_child_node(child.node_id, path_to_child[1:], - debugging_name, raise_error) - - if raise_error: - raise ValueError( - 'Error when loading {}: could not find attribute {}.\n' - 'Most likely this object was serialized incorrectly.' - .format(debugging_name or path_to_child[0], path_to_child[0])) - else: - return None + return self._search_for_child_node(child.node_id, path_to_child[1:]) + return None def _infer_inputs(self, layer_node_id, convert_to_shapes=False): """Infers input shape of layer from SavedModel functions.""" coder = nested_structure_coder.StructureCoder() call_fn_id = self._search_for_child_node( - layer_node_id, ['call_and_return_all_conditional_losses'], None, - raise_error=False) + layer_node_id, ['call_and_return_all_conditional_losses']) if call_fn_id is None: return None @@ -738,9 +818,9 @@ def _unable_to_call_layer_due_to_serialization_issue( """ raise ValueError( - 'Cannot call {} ({}), because the call function was not serialized to ' - 'the SavedModel (due to lack information about the inputs). Please try ' - 'one of the following methods to fix the serialization:' + 'Cannot call custom layer {} of type {}, because the call function was ' + 'not serialized to the SavedModel.' + 'Please try one of the following methods to fix this issue:' '\n\n(1) Implement `get_config` and `from_config` in the layer/model ' 'class, and pass the object to the `custom_objects` argument when ' 'loading the model. For more details, see: ' @@ -749,7 +829,7 @@ def _unable_to_call_layer_due_to_serialization_issue( 'and not `__call__`. The input shape and dtype will be automatically ' 'recorded when the object is called, and used when saving. To manually ' 'specify the input shape/dtype, decorate the call function with ' - '`@tf.function(input_signature=...)`.'.format(layer.name, layer)) + '`@tf.function(input_signature=...)`.'.format(layer.name, type(layer))) def _finalize_config_layers(layers): @@ -917,8 +997,11 @@ def _revive_setter(layer, name, value): elif (isinstance(layer, functional_lib.Functional) and re.match(r'^layer(_with_weights)?-[\d+]', name) is not None): # Edges named "layer-n" or "layer_with_weights-n", which are tracked in - # network._track_layers, should not be added as an attribute. - pass + # network._track_layers, should not be added as an attribute. They should + # be temporarily added as a dependency so that checkpointed values can be + # restored. These dependencies are manually deleted in + # KerasObjectLoader.del_tracking. + layer._track_trackable(value, name) # pylint: disable=protected-access elif getattr(layer, name, None) is not None: # Don't overwrite already defined attributes. pass diff --git a/tensorflow/python/keras/saving/saved_model/save.py b/tensorflow/python/keras/saving/saved_model/save.py index a40856cbf54..16984a2221b 100644 --- a/tensorflow/python/keras/saving/saved_model/save.py +++ b/tensorflow/python/keras/saving/saved_model/save.py @@ -22,6 +22,7 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.keras import backend as K from tensorflow.python.keras.saving import saving_utils from tensorflow.python.keras.saving.saved_model import save_impl +from tensorflow.python.keras.saving.saved_model import utils from tensorflow.python.keras.utils.generic_utils import LazyLoader from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.saved_model import save as save_lib @@ -38,7 +39,7 @@ training_lib = LazyLoader( def save(model, filepath, overwrite, include_optimizer, signatures=None, - options=None): + options=None, save_traces=True): """Saves a model as a SavedModel to the filepath. Args: @@ -49,8 +50,14 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None, signatures: Signatures to save with the SavedModel. Applicable to the 'tf' format only. Please see the `signatures` argument in `tf.saved_model.save` for details. - options: Optional `tf.saved_model.SaveOptions` object that specifies - options for saving to SavedModel. + options: (only applies to SavedModel format) `tf.saved_model.SaveOptions` + object that specifies options for saving to SavedModel. + save_traces: (only applies to SavedModel format) When enabled, the + SavedModel will store the function traces for each layer. This + can be disabled, so that only the configs of each layer are stored. + Defaults to `True`. Disabling this will decrease serialization time + and reduce file size, but it requires that all custom layers/models + implement a `get_config()` method. Raises: ValueError: if the model's inputs have not been defined. @@ -61,8 +68,9 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None, if not proceed: return - if save_impl.should_skip_serialization(model): - saving_utils.raise_model_input_error(model) + if save_traces: + if save_impl.should_skip_serialization(model): + saving_utils.raise_model_input_error(model) if not include_optimizer: orig_optimizer = model.optimizer @@ -77,7 +85,8 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None, # the replica context is not available when calling `add_update()`, and thus # we use the default replica context here. with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access - save_lib.save(model, filepath, signatures, options) + with utils.keras_option_scope(save_traces): + save_lib.save(model, filepath, signatures, options) if not include_optimizer: model.optimizer = orig_optimizer diff --git a/tensorflow/python/keras/saving/saved_model/save_impl.py b/tensorflow/python/keras/saving/saved_model/save_impl.py index 97c8fc313f0..a2e74deaaeb 100644 --- a/tensorflow/python/keras/saving/saved_model/save_impl.py +++ b/tensorflow/python/keras/saving/saved_model/save_impl.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import input_spec -from tensorflow.python.keras.mixed_precision.experimental import autocast_variable +from tensorflow.python.keras.mixed_precision import autocast_variable from tensorflow.python.keras.saving import saving_utils from tensorflow.python.keras.saving.saved_model import constants from tensorflow.python.keras.saving.saved_model import load as keras_load @@ -628,7 +628,9 @@ def _wrap_activity_regularizer(layer): return def_function.Function( layer._activity_regularizer, '{}_activity_regularizer'.format(layer.name), - input_signature=[tensor_spec.TensorSpec(None, layer.dtype or K.floatx())]) + input_signature=[ + tensor_spec.TensorSpec(None, layer._compute_dtype or K.floatx()) + ]) # pylint: enable=protected-access diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 3cb960d0971..12a3a7761b8 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -26,7 +26,6 @@ from __future__ import print_function import os import shutil -import sys from absl.testing import parameterized import numpy as np @@ -114,7 +113,7 @@ class GlobalLayerThatShouldFailIfNotAdded(keras.layers.Layer): @keras_parameterized.run_all_keras_modes -class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): +class TestSavedModelFormatAllModes(keras_parameterized.TestCase): def _save_model_dir(self, dirname='saved_model'): temp_dir = self.get_temp_dir() @@ -411,14 +410,16 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): self.evaluate(variables.variables_initializer(model.variables)) saved_model_dir = self._save_model_dir() - with self.captureWritesToStream(sys.stderr) as captured_logs: - model.save(saved_model_dir, save_format='tf') - loaded = keras_load.load(saved_model_dir) + # TODO(kathywu): Re-enable this check after removing the tf.saved_model.save + # metadata warning. + # with self.captureWritesToStream(sys.stderr) as captured_logs: + model.save(saved_model_dir, save_format='tf') + loaded = keras_load.load(saved_model_dir) # Assert that saving does not log deprecation warnings # (even if it needs to set learning phase for compat reasons) - if context.executing_eagerly(): - self.assertNotIn('deprecated', captured_logs.contents()) + # if context.executing_eagerly(): + # self.assertNotIn('deprecated', captured_logs.contents()) input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32) input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32) @@ -829,6 +830,14 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): self.evaluate(variables.variables_initializer(loaded.variables)) self.assertAllClose(model.predict(f), loaded.predict(f)) + +class TestSavedModelFormat(test.TestCase): + + def _save_model_dir(self, dirname='saved_model'): + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) + return os.path.join(temp_dir, dirname) + def test_load_with_partially_failed_serialization(self): class BadCustomLayer(keras.layers.Layer): @@ -858,6 +867,48 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): with self.assertRaisesRegex(ValueError, 'call function was not serialized'): loaded.layer(inp) + def test_save_without_tracing(self): + + class DoNotTrace(keras.layers.Layer): + + def __init__(self): + super(DoNotTrace, self).__init__() + self.input_spec = keras.layers.InputSpec(shape=[None]) + self.built = True + + def call(self, inputs): + raise ValueError('I said do not trace') + + def get_config(self): + return {} + + root = keras.models.Sequential() + root.add(keras.layers.Input(shape=(3,))) + root.attached_layer = DoNotTrace() + + saved_model_dir = self._save_model_dir() + + # With the default settings, the call function is traced. + with self.assertRaisesRegex(ValueError, 'do not trace'): + root.save(saved_model_dir, save_format='tf') + + # When saving the config only, the layer call function should not be not + # traced. + root.save(saved_model_dir, save_format='tf', save_traces=False) + loaded = tf_load.load(saved_model_dir) + self.assertTrue(hasattr(loaded, 'attached_layer')) + + # This should raise an error when loaded without the custom object + loaded = keras_load.load(saved_model_dir) + with self.assertRaisesRegex(ValueError, 'Cannot call custom layer'): + loaded.attached_layer(constant_op.constant([1.])) + + # Try loading with the custom objects + with generic_utils.CustomObjectScope({'DoNotTrace': DoNotTrace}): + loaded = keras_load.load(saved_model_dir) + with self.assertRaisesRegex(ValueError, 'I said do not trace'): + loaded.attached_layer(constant_op.constant([1.])) + class TestLayerCallTracing(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/keras/saving/saved_model/utils.py b/tensorflow/python/keras/saving/saved_model/utils.py index 5ed68e775d8..c3e470f6069 100644 --- a/tensorflow/python/keras/saving/saved_model/utils.py +++ b/tensorflow/python/keras/saving/saved_model/utils.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import itertools +import threading import types from tensorflow.python.eager import context @@ -25,6 +26,7 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils.generic_utils import LazyLoader from tensorflow.python.util import tf_decorator @@ -245,3 +247,27 @@ def remove_training_arg(index, args, kwargs): args.pop(index) else: kwargs.pop('training', None) + + +class SaveOptionsContext(threading.local): + + def __init__(self): + super(SaveOptionsContext, self).__init__() + self.save_traces = True + + +_save_options_context = SaveOptionsContext() + + +@tf_contextlib.contextmanager +def keras_option_scope(save_traces): + previous_value = _save_options_context.save_traces + try: + _save_options_context.save_traces = save_traces + yield + finally: + _save_options_context.save_traces = previous_value + + +def should_save_traces(): + return _save_options_context.save_traces diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index 6d212e0cda3..fecf52e71b3 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -304,6 +304,7 @@ _thread_local_data = threading.local() _thread_local_data.model_type = None _thread_local_data.run_eagerly = None _thread_local_data.saved_model_format = None +_thread_local_data.save_kwargs = None @tf_contextlib.contextmanager @@ -383,7 +384,7 @@ def should_run_eagerly(): @tf_contextlib.contextmanager -def saved_model_format_scope(value): +def saved_model_format_scope(value, **kwargs): """Provides a scope within which the savde model format to test is `value`. The saved model format gets restored to its original value upon exiting the @@ -391,17 +392,21 @@ def saved_model_format_scope(value): Arguments: value: saved model format value + **kwargs: optional kwargs to pass to the save function. Yields: The provided value. """ - previous_value = _thread_local_data.saved_model_format + previous_format = _thread_local_data.saved_model_format + previous_kwargs = _thread_local_data.save_kwargs try: _thread_local_data.saved_model_format = value - yield value + _thread_local_data.save_kwargs = kwargs + yield finally: # Restore saved model format to initial value. - _thread_local_data.saved_model_format = previous_value + _thread_local_data.saved_model_format = previous_format + _thread_local_data.save_kwargs = previous_kwargs def get_save_format(): @@ -413,6 +418,15 @@ def get_save_format(): return _thread_local_data.saved_model_format +def get_save_kwargs(): + if _thread_local_data.save_kwargs is None: + raise ValueError( + 'Cannot call `get_save_kwargs()` outside of a ' + '`saved_model_format_scope()` or `run_with_all_saved_model_formats` ' + 'decorator.') + return _thread_local_data.save_kwargs or {} + + def get_model_type(): """Gets the model type that should be tested.""" if _thread_local_data.model_type is None: diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD index 0ac8abcbe18..201340c3b4e 100644 --- a/tensorflow/python/keras/tests/BUILD +++ b/tensorflow/python/keras/tests/BUILD @@ -29,7 +29,6 @@ tf_py_test( tags = [ "no_pip", ], - tfrt_enabled = True, deps = [ ":get_config_samples", "//tensorflow/python:client_testlib", @@ -57,7 +56,6 @@ tpu_py_test( "automatic_outside_compilation_test.py", ], disable_experimental = True, - disable_mlir_bridge = False, python_version = "PY3", tags = ["no_oss"], deps = [ @@ -129,7 +127,6 @@ tf_py_test( name = "graph_util_test", srcs = ["graph_util_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -272,7 +269,6 @@ tf_py_test( "nomsan", # TODO(b/149948895): Re-enable. "notsan", # TODO(b/149948895): Re-enable. ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -305,7 +301,6 @@ cuda_py_test( name = "summary_ops_test", size = "small", srcs = ["summary_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -323,7 +318,6 @@ tf_py_test( name = "saved_model_test", size = "small", srcs = ["saved_model_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -370,7 +364,6 @@ tf_py_test( size = "small", srcs = ["serialization_util_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -403,7 +396,6 @@ tf_py_test( "no_windows", "nomac", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", @@ -432,7 +424,6 @@ tf_py_test( srcs = ["tracking_util_test.py"], python_version = "PY3", tags = ["notsan"], # b/74395663 - tfrt_enabled = True, deps = [ "//tensorflow/compiler/tests:xla_test", "//tensorflow/python:checkpoint_management", diff --git a/tensorflow/python/keras/tests/automatic_outside_compilation_test.py b/tensorflow/python/keras/tests/automatic_outside_compilation_test.py index 64733746776..ceedfa0a6b1 100644 --- a/tensorflow/python/keras/tests/automatic_outside_compilation_test.py +++ b/tensorflow/python/keras/tests/automatic_outside_compilation_test.py @@ -246,33 +246,42 @@ class AutoOutsideCompilationWithKerasTest(test.TestCase): def testSummaryWithCustomTrainingLoop(self): strategy = get_tpu_strategy() + writer = summary_ops_v2.create_file_writer_v2(self.summary_dir) with strategy.scope(): model = distribute_strategy_test.get_model() model.compile('sgd', 'mse') - writer = summary_ops_v2.create_file_writer_v2(self.summary_dir) - @def_function.function - def custom_function(dataset): + @def_function.function + def custom_function(dataset): - def _custom_step(features, labels): - del labels - logits = model(features) - with summary_ops_v2.always_record_summaries(), writer.as_default(): - scalar_summary_v2.scalar( - 'logits', - math_ops.reduce_sum(logits), - step=model.optimizer.iterations) - return logits + def _custom_step(features, labels): + del labels + logits = model(features) + with summary_ops_v2.always_record_summaries(), writer.as_default(): + scalar_summary_v2.scalar( + 'logits', + math_ops.reduce_sum(logits), + step=model.optimizer.iterations) + return logits - iterator = iter(dataset) - output = strategy.unwrap( - strategy.run(_custom_step, args=(next(iterator)))) - return output + iterator = iter(dataset) + output = strategy.unwrap( + strategy.run(_custom_step, args=(next(iterator)))) + return output - dataset = strategy.experimental_distribute_dataset( - distribute_strategy_test.get_dataset(strategy)) + dataset = strategy.experimental_distribute_dataset( + distribute_strategy_test.get_dataset(strategy)) - custom_function(dataset) + custom_function(dataset) + writer.close() + + event_files = file_io.get_matching_files_v2( + os.path.join(self.summary_dir, 'event*')) + events_count_dictionary = { + ('logits'): 0, + } + self.validate_recorded_sumary_file(event_files, events_count_dictionary, + 1) if __name__ == '__main__': diff --git a/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py b/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py index 18c91f3b622..3463a7862bd 100644 --- a/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py +++ b/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py @@ -27,7 +27,6 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras import combinations @@ -460,16 +459,6 @@ class CheckpointingTests(keras_parameterized.TestCase): self.evaluate(root.save_counter)) # pylint: enable=cell-var-from-loop - def _get_checkpoint_name(self, name): - root = module.Module() - trackable_utils.add_variable( - root, name=name, shape=[1, 2], dtype=dtypes.float64) - (named_variable,), _, _ = trackable_utils._serialize_object_graph( - root, saveables_cache=None) - with ops.name_scope_v2("root/" + named_variable.name): - pass # Make sure we can use this as an op name if we prefix it. - return named_variable.name - @combinations.generate(combinations.combine(mode=["eager"])) def testAnonymousVarsInInit(self): diff --git a/tensorflow/python/keras/type/types.py b/tensorflow/python/keras/type/types.py index bf83670567c..77e52990fbe 100644 --- a/tensorflow/python/keras/type/types.py +++ b/tensorflow/python/keras/type/types.py @@ -176,9 +176,9 @@ class Layer(object): Attributes: name: The name of the layer (string). dtype: The dtype of the layer's computations and weights. If mixed - precision is used with a `tf.keras.mixed_precision.experimental.Policy`, - this is instead just the dtype of the layer's weights, as the computations - are done in a different dtype. + precision is used with a `tf.keras.mixed_precision.Policy`, this is + instead just the dtype of the layer's weights, as the computations are + done in a different dtype. updates: List of update ops of this layer. losses: List of losses added by this layer. trainable_weights: List of variables to be included in backprop. @@ -197,7 +197,6 @@ class Layer(object): if no dtype is passed. `floatx()` itself defaults to "float32". Additionally, layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed precision is used, layers may have different computation and variable dtypes. - See `tf.keras.mixed_precision.experimental.Policy` for details on layer - dtypes. + See `tf.keras.mixed_precision.Policy` for details on layer dtypes. """ pass diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD index 30b93fa95d8..59fd7235869 100644 --- a/tensorflow/python/keras/utils/BUILD +++ b/tensorflow/python/keras/utils/BUILD @@ -276,7 +276,6 @@ tf_py_test( size = "small", srcs = ["version_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":version_utils", "//tensorflow/python:client_testlib", @@ -290,7 +289,6 @@ tf_py_test( size = "small", srcs = ["tf_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":tf_utils", "//tensorflow/python:client_testlib", @@ -332,7 +330,6 @@ tf_py_test( "no_windows", # TODO: needs investigation on Windows "notsan", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -346,7 +343,6 @@ tf_py_test( size = "small", srcs = ["layer_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":layer_utils", "//tensorflow/python:client_testlib", @@ -360,7 +356,6 @@ tf_py_test( size = "small", srcs = ["np_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -374,7 +369,6 @@ tf_py_test( size = "small", srcs = ["kernelized_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":layer_utils", "//tensorflow/python:client_testlib", @@ -406,7 +400,6 @@ tf_py_test( size = "small", srcs = ["vis_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -420,7 +413,6 @@ tf_py_test( size = "small", srcs = ["conv_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -434,7 +426,6 @@ tf_py_test( size = "small", srcs = ["metrics_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", diff --git a/tensorflow/python/keras/utils/tf_inspect.py b/tensorflow/python/keras/utils/tf_inspect.py index dd13ea6c393..8f1b668539a 100644 --- a/tensorflow/python/keras/utils/tf_inspect.py +++ b/tensorflow/python/keras/utils/tf_inspect.py @@ -90,6 +90,11 @@ else: return _convert_maybe_argspec_to_fullargspec(getargspec(target)) +def currentframe(): + """TFDecorator-aware replacement for inspect.currentframe.""" + return _inspect.stack()[1][0] + + def getargspec(obj): """TFDecorator-aware replacement for `inspect.getargspec`. diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index b105d1f987a..21c27ea74e4 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -20,7 +20,6 @@ tf_py_test( size = "small", srcs = ["as_string_op_test.py"], tags = ["no_windows"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -34,7 +33,6 @@ tf_py_test( name = "attention_ops_test", size = "small", srcs = ["attention_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -54,7 +52,6 @@ tf_py_test( "nomsan", # TODO(b/161902335): Re-enable. "notsan", # TODO(b/161829717): Re-enable. ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:data_flow_ops", @@ -69,7 +66,6 @@ tf_py_test( size = "small", srcs = ["base64_ops_test.py"], tags = ["nomac"], # b/35468214 - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -84,7 +80,6 @@ tf_py_test( tf_py_test( name = "batch_scatter_ops_test", srcs = ["batch_scatter_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -105,7 +100,6 @@ tf_py_test( name = "bcast_ops_test", size = "small", srcs = ["bcast_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops_gen", "//tensorflow/python:client_testlib", @@ -168,7 +162,6 @@ cuda_py_test( size = "small", srcs = ["benchmark_test.py"], tags = ["no_windows"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -181,7 +174,6 @@ cuda_py_test( cuda_py_test( name = "reduce_benchmark_test", srcs = ["reduce_benchmark_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -200,7 +192,6 @@ cuda_py_test( size = "small", srcs = ["bincount_op_test.py"], tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/python:bincount_ops", "//tensorflow/python:client_testlib", @@ -212,7 +203,6 @@ tf_py_test( name = "candidate_sampler_ops_test", size = "small", srcs = ["candidate_sampler_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:candidate_sampling_ops", @@ -227,7 +217,6 @@ tf_py_test( name = "checkpoint_ops_test", size = "medium", srcs = ["checkpoint_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:checkpoint_ops_gen", @@ -275,7 +264,6 @@ tf_py_test( "no_gpu", # b/127001953 "no_windows", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:clip_ops", @@ -283,17 +271,27 @@ tf_py_test( ], ) -tf_py_test( +cuda_py_test( name = "collective_ops_test", size = "medium", srcs = ["collective_ops_test.py"], - tfrt_enabled = False, + tags = [ + "multi_and_single_gpu", + ], deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:collective_ops_gen", - "//tensorflow/python:framework_combinations", + "//tensorflow/python:collective_ops", + "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/data/experimental/ops:testing", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:test_util", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", + "@absl_py//absl/testing:parameterized", ], ) @@ -302,17 +300,22 @@ tf_py_test( size = "medium", srcs = ["collective_ops_multi_worker_test.py"], python_version = "PY3", - tags = ["no_rocm"], - tfrt_enabled = False, + tags = [ + "no_oss_py38", #TODO(b/171435331) + "no_rocm", + "notsan", # TODO(b/171435192) + ], deps = [ "//tensorflow/python:collective_ops", "//tensorflow/python:constant_op", "//tensorflow/python:errors", + "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:multi_process_runner", "//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", ], ) @@ -320,7 +323,6 @@ tf_py_test( name = "conditional_accumulator_test", size = "small", srcs = ["conditional_accumulator_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -338,7 +340,6 @@ tf_py_test( name = "ctc_decoder_ops_test", size = "small", srcs = ["ctc_decoder_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -379,7 +380,9 @@ cuda_py_test( name = "cudnn_deterministic_ops_test", size = "small", srcs = ["cudnn_deterministic_ops_test.py"], - tfrt_enabled = True, + tags = [ + "no_cuda_asan", # TODO(b/171509035): re-enable. + ], xla_enable_strict_auto_jit = True, deps = [ ":cudnn_deterministic_base", @@ -390,7 +393,9 @@ cuda_py_test( name = "cudnn_deterministic_test", size = "small", srcs = ["cudnn_deterministic_test.py"], - tfrt_enabled = True, + tags = [ + "no_cuda_asan", # TODO(b/171509035): re-enable. + ], deps = [ ":cudnn_deterministic_base", ], @@ -400,7 +405,6 @@ cuda_py_test( name = "cumulative_logsumexp_test", size = "medium", srcs = ["cumulative_logsumexp_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -416,7 +420,6 @@ tf_py_test( name = "decode_csv_op_test", size = "small", srcs = ["decode_csv_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -431,7 +434,6 @@ tf_py_test( name = "decode_png_op_test", size = "small", srcs = ["decode_png_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -445,7 +447,6 @@ tf_py_test( name = "decode_bmp_op_test", size = "small", srcs = ["decode_bmp_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -459,7 +460,6 @@ tf_py_test( name = "decode_jpeg_op_test", srcs = ["decode_jpeg_op_test.py"], data = ["//tensorflow/core:image_testdata"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -473,7 +473,6 @@ tf_py_test( size = "small", srcs = ["decode_image_op_test.py"], data = ["//tensorflow/core:image_testdata"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -488,7 +487,6 @@ tf_py_test( name = "decode_raw_op_test", size = "small", srcs = ["decode_raw_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -502,7 +500,6 @@ tf_py_test( name = "decode_compressed_op_test", size = "small", srcs = ["decode_compressed_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -516,7 +513,6 @@ cuda_py_test( name = "determinant_op_test", size = "medium", srcs = ["determinant_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -529,7 +525,6 @@ tf_py_test( name = "draw_bounding_box_op_test", size = "small", srcs = ["draw_bounding_box_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -544,7 +539,6 @@ tf_py_test( name = "edit_distance_op_test", size = "small", srcs = ["edit_distance_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -576,7 +570,6 @@ tf_py_test( name = "fingerprint_op_test", size = "small", srcs = ["fingerprint_op_test.py"], - tfrt_enabled = True, deps = [ "//third_party/py/numpy", ], @@ -587,7 +580,6 @@ tf_py_test( size = "small", srcs = ["fractional_avg_pool_op_test.py"], shard_count = 5, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -604,7 +596,6 @@ tf_py_test( size = "small", srcs = ["fractional_max_pool_op_test.py"], shard_count = 5, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -620,7 +611,6 @@ tf_py_test( name = "identity_op_py_test", size = "small", srcs = ["identity_op_py_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", @@ -635,7 +625,6 @@ tf_py_test( name = "identity_n_op_py_test", size = "small", srcs = ["identity_n_op_py_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", @@ -650,7 +639,6 @@ cuda_py_test( name = "in_topk_op_test", size = "small", srcs = ["in_topk_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -663,7 +651,6 @@ tf_py_test( name = "record_input_test", size = "medium", srcs = ["record_input_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:data_flow_ops", @@ -676,7 +663,6 @@ tf_py_test( name = "io_ops_test", size = "small", srcs = ["io_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:io_ops", @@ -688,7 +674,6 @@ tf_py_test( name = "listdiff_op_test", size = "small", srcs = ["listdiff_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -705,7 +690,6 @@ tf_py_test( tags = [ "no_windows", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -718,7 +702,6 @@ cuda_py_test( name = "logging_ops_test", size = "small", srcs = ["logging_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", @@ -756,7 +739,6 @@ tf_py_test( name = "losses_test", size = "medium", srcs = ["losses_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -817,7 +799,9 @@ cuda_py_test( name = "matrix_solve_ls_op_test", size = "medium", srcs = ["matrix_solve_ls_op_test.py"], - tfrt_enabled = True, + tags = [ + "noasan", # TODO(b/337374867) fails with -fsanitize=null + ], deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -832,7 +816,6 @@ cuda_py_test( name = "matrix_square_root_op_test", size = "medium", srcs = ["matrix_square_root_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -858,7 +841,6 @@ cuda_py_test( name = "banded_triangular_solve_op_test", size = "small", srcs = ["banded_triangular_solve_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:linalg_ops", @@ -871,7 +853,6 @@ cuda_py_test( size = "medium", srcs = ["matrix_triangular_solve_op_test.py"], shard_count = 3, - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:linalg_ops", @@ -917,7 +898,6 @@ tf_py_test( name = "partitioned_variables_test", size = "small", srcs = ["partitioned_variables_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -935,7 +915,6 @@ tf_py_test( name = "priority_queue_test", size = "medium", srcs = ["priority_queue_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -955,7 +934,6 @@ cuda_py_test( # # TODO(b/128347673): Re-enable. tags = ["no_windows"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -974,7 +952,6 @@ tf_py_test( name = "regex_replace_op_test", size = "small", srcs = ["regex_replace_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -988,7 +965,6 @@ tf_py_test( name = "regex_full_match_op_test", size = "small", srcs = ["regex_full_match_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -1002,7 +978,6 @@ tf_py_test( name = "save_restore_ops_test", size = "small", srcs = ["save_restore_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client", @@ -1033,7 +1008,6 @@ tf_py_test( name = "sparse_add_op_test", size = "small", srcs = ["sparse_add_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -1050,7 +1024,6 @@ tf_py_test( name = "sparse_concat_op_test", size = "small", srcs = ["sparse_concat_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1065,7 +1038,6 @@ tf_py_test( name = "sparse_conditional_accumulator_test", size = "small", srcs = ["sparse_conditional_accumulator_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1080,7 +1052,6 @@ tf_py_test( name = "sparse_reorder_op_test", size = "small", srcs = ["sparse_reorder_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1096,7 +1067,6 @@ tf_py_test( name = "sparse_reshape_op_test", size = "small", srcs = ["sparse_reshape_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1112,7 +1082,6 @@ tf_py_test( name = "sparse_split_op_test", size = "small", srcs = ["sparse_split_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework", @@ -1125,7 +1094,6 @@ tf_py_test( name = "sparse_slice_op_test", size = "small", srcs = ["sparse_slice_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework", @@ -1139,7 +1107,6 @@ tf_py_test( name = "sparse_to_dense_op_py_test", size = "small", srcs = ["sparse_to_dense_op_py_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1153,7 +1120,6 @@ tf_py_test( name = "sparsemask_op_test", size = "small", srcs = ["sparsemask_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1166,7 +1132,6 @@ tf_py_test( name = "string_format_op_test", size = "small", srcs = ["string_format_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -1180,7 +1145,6 @@ tf_py_test( name = "string_join_op_test", size = "small", srcs = ["string_join_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:string_ops", @@ -1210,7 +1174,6 @@ tf_py_test( name = "string_bytes_split_op_test", size = "small", srcs = ["string_bytes_split_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1230,7 +1193,6 @@ tf_py_test( name = "string_length_op_test", size = "small", srcs = ["string_length_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -1242,7 +1204,6 @@ tf_py_test( name = "string_strip_op_test", size = "small", srcs = ["string_strip_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1257,7 +1218,6 @@ tf_py_test( name = "string_lower_op_test", size = "small", srcs = ["string_lower_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1272,7 +1232,6 @@ tf_py_test( name = "string_upper_op_test", size = "small", srcs = ["string_upper_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1287,7 +1246,6 @@ tf_py_test( name = "substr_op_test", size = "small", srcs = ["substr_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -1326,7 +1284,6 @@ tf_py_test( name = "summary_v1_ops_test", size = "small", srcs = ["summary_v1_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -1340,7 +1297,6 @@ tf_py_test( name = "summary_v1_tensor_op_test", size = "small", srcs = ["summary_v1_tensor_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -1357,7 +1313,6 @@ tf_py_test( name = "template_test", size = "small", srcs = ["template_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -1376,7 +1331,6 @@ cuda_py_test( name = "template_mirrored_strategy_test", size = "small", srcs = ["template_mirrored_strategy_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:init_ops", @@ -1408,7 +1362,6 @@ tf_py_test( name = "unicode_script_op_test", size = "small", srcs = ["unicode_script_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -1421,7 +1374,6 @@ cuda_py_test( name = "topk_op_test", size = "medium", srcs = ["topk_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1437,7 +1389,6 @@ cuda_py_test( name = "nth_element_op_test", size = "small", srcs = ["nth_element_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1453,7 +1404,6 @@ tf_py_test( name = "unicode_encode_op_test", size = "small", srcs = ["unicode_encode_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -1472,7 +1422,6 @@ tf_py_test( name = "unicode_transcode_op_test", size = "small", srcs = ["unicode_transcode_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -1505,7 +1454,6 @@ tf_py_test( name = "unique_op_test", size = "small", srcs = ["unique_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1544,7 +1492,6 @@ tf_py_test( size = "small", srcs = ["variables_test.py"], tags = ["no_windows"], # b/133869052 - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1567,7 +1514,6 @@ cuda_py_test( name = "where_op_test", size = "medium", srcs = ["where_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1580,7 +1526,6 @@ cuda_py_test( name = "cast_op_test", size = "small", srcs = ["cast_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1597,7 +1542,6 @@ cuda_py_test( size = "small", srcs = ["dense_update_ops_no_tsan_test.py"], tags = ["notsan"], - tfrt_enabled = True, # TODO (b/140294007): the test fails with XLA. xla_enable_strict_auto_jit = False, deps = [ @@ -1616,7 +1560,6 @@ cuda_py_test( srcs = ["diag_op_test.py"], shard_count = 6, tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1632,7 +1575,6 @@ tf_py_test( size = "small", srcs = ["reader_ops_test.py"], data = ["//tensorflow/core/lib/lmdb:lmdb_testdata"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -1651,7 +1593,6 @@ cuda_py_test( name = "aggregate_ops_test", size = "small", srcs = ["aggregate_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1665,7 +1606,6 @@ cuda_py_test( name = "argmax_op_test", size = "small", srcs = ["argmax_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:math_ops", @@ -1719,7 +1659,6 @@ cuda_py_test( size = "small", srcs = ["inplace_ops_test.py"], shard_count = 10, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1749,7 +1688,6 @@ cuda_py_test( name = "batchtospace_op_test", size = "small", srcs = ["batchtospace_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", @@ -1763,7 +1701,6 @@ cuda_py_test( name = "betainc_op_test", size = "small", srcs = ["betainc_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1792,7 +1729,6 @@ cuda_py_test( name = "bias_op_deterministic_test", size = "medium", srcs = ["bias_op_deterministic_test.py"], - tfrt_enabled = True, deps = [ ":bias_op_base", ], @@ -1811,7 +1747,6 @@ cuda_py_test( name = "bitcast_op_test", size = "small", srcs = ["bitcast_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1824,7 +1759,6 @@ cuda_py_test( name = "check_ops_test", size = "small", srcs = ["check_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", @@ -1843,7 +1777,6 @@ cuda_py_test( name = "constant_op_test", size = "small", srcs = ["constant_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1859,7 +1792,6 @@ cuda_py_test( name = "constant_op_eager_test", size = "small", srcs = ["constant_op_eager_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1921,7 +1853,6 @@ tf_py_test( name = "control_flow_util_test", size = "small", srcs = ["control_flow_util_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", @@ -1935,7 +1866,6 @@ tf_py_test( name = "control_flow_util_v2_test", size = "small", srcs = ["control_flow_util_v2_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:cond_v2", @@ -1950,7 +1880,6 @@ cuda_py_test( name = "conv1d_test", size = "small", srcs = ["conv1d_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1963,7 +1892,6 @@ cuda_py_test( name = "conv1d_transpose_test", size = "small", srcs = ["conv1d_transpose_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -1977,7 +1905,6 @@ cuda_py_test( name = "conv2d_transpose_test", size = "small", srcs = ["conv2d_transpose_test.py"], - tfrt_enabled = True, # TODO(b/144432983): S32 convolutions should not be auto-clustered, only # crashes tests. @@ -1996,7 +1923,6 @@ cuda_py_test( name = "conv3d_backprop_filter_v2_grad_test", size = "small", srcs = ["conv3d_backprop_filter_v2_grad_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2011,7 +1937,6 @@ cuda_py_test( name = "cross_grad_test", size = "small", srcs = ["cross_grad_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2036,7 +1961,6 @@ cuda_py_test( name = "dense_update_ops_test", size = "small", srcs = ["dense_update_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2053,7 +1977,6 @@ cuda_py_test( size = "medium", srcs = ["depthtospace_op_test.py"], tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2068,7 +1991,6 @@ cuda_py_test( size = "medium", srcs = ["division_past_test.py"], tags = ["manual"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -2083,7 +2005,7 @@ cuda_py_test( tags = [ "multi_and_single_gpu", ], - tfrt_enabled = False, # TODO(b/153089059): add support for complex128. + # TODO(b/153089059): TFRT: Add support for complex128. deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2099,7 +2021,6 @@ cuda_py_test( name = "dynamic_stitch_op_test", size = "small", srcs = ["dynamic_stitch_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:data_flow_grad", @@ -2114,7 +2035,6 @@ cuda_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], - tfrt_enabled = True, # TODO(b/144432983): S32 convolutions should not be auto-clustered. xla_enable_strict_auto_jit = False, deps = [ @@ -2129,7 +2049,6 @@ cuda_py_test( name = "extract_volume_patches_op_test", size = "small", srcs = ["extract_volume_patches_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2168,7 +2087,6 @@ cuda_py_test( name = "gather_nd_op_test", size = "small", srcs = ["gather_nd_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client", @@ -2184,7 +2102,6 @@ cuda_py_test( name = "gradient_correctness_test", size = "small", srcs = ["gradient_correctness_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -2226,7 +2143,6 @@ cuda_py_test( size = "medium", srcs = ["linalg_ops_test.py"], tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2243,7 +2159,6 @@ cuda_py_test( name = "lrn_op_test", size = "medium", srcs = ["lrn_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2259,7 +2174,6 @@ cuda_py_test( name = "lu_op_test", size = "small", srcs = ["lu_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2294,7 +2208,6 @@ cuda_py_test( size = "small", srcs = ["manip_ops_test.py"], tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -2324,7 +2237,6 @@ cuda_py_test( name = "morphological_ops_test", size = "small", srcs = ["morphological_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -2338,7 +2250,6 @@ cuda_py_test( name = "numerics_test", size = "small", srcs = ["numerics_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2355,7 +2266,6 @@ cuda_py_test( size = "small", srcs = ["one_hot_op_test.py"], tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2368,7 +2278,6 @@ cuda_py_test( name = "stack_op_test", size = "small", srcs = ["stack_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2405,7 +2314,6 @@ cuda_py_test( name = "pad_op_test", size = "small", srcs = ["pad_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2418,7 +2326,6 @@ cuda_py_test( name = "padding_fifo_queue_test", size = "small", srcs = ["padding_fifo_queue_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2452,7 +2359,6 @@ cuda_py_test( name = "reduce_join_op_test", size = "small", srcs = ["reduce_join_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2502,7 +2408,6 @@ cuda_py_test( "no_gpu", "noguitar", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2516,7 +2421,6 @@ cuda_py_test( name = "relu_op_test", size = "small", srcs = ["relu_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -2535,7 +2439,6 @@ cuda_py_test( name = "reshape_op_test", size = "small", srcs = ["reshape_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2548,7 +2451,6 @@ cuda_py_test( name = "reverse_sequence_op_test", size = "small", srcs = ["reverse_sequence_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2561,7 +2463,6 @@ cuda_py_test( name = "compare_and_bitpack_op_test", size = "small", srcs = ["compare_and_bitpack_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -2574,7 +2475,6 @@ cuda_py_test( name = "scalar_test", size = "small", srcs = ["scalar_test.py"], - tfrt_enabled = True, # b/140221961: Invalid dims for operations xla_enable_strict_auto_jit = False, deps = [ @@ -2595,7 +2495,6 @@ cuda_py_test( name = "scan_ops_test", size = "medium", srcs = ["scan_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -2653,7 +2552,6 @@ cuda_py_test( name = "softsign_op_test", size = "small", srcs = ["softsign_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -2667,7 +2565,6 @@ cuda_py_test( name = "spacetobatch_op_test", size = "small", srcs = ["spacetobatch_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", @@ -2687,7 +2584,6 @@ cuda_py_test( "no_windows", "no_windows_gpu", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2701,7 +2597,6 @@ tf_py_test( name = "sparse_serialization_ops_test", size = "small", srcs = ["sparse_serialization_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2716,7 +2611,6 @@ tf_py_test( name = "sparse_tensors_map_ops_test", size = "small", srcs = ["sparse_tensors_map_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client", @@ -2733,7 +2627,6 @@ cuda_py_test( name = "sparse_tensor_dense_matmul_grad_test", size = "small", srcs = ["sparse_tensor_dense_matmul_grad_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework", @@ -2748,7 +2641,6 @@ cuda_py_test( name = "sparse_xent_op_test", size = "small", srcs = ["sparse_xent_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -2787,7 +2679,6 @@ cuda_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", @@ -2803,7 +2694,6 @@ cuda_py_test( name = "string_to_hash_bucket_op_test", size = "small", srcs = ["string_to_hash_bucket_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2816,7 +2706,6 @@ cuda_py_test( name = "string_to_number_op_test", size = "small", srcs = ["string_to_number_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2829,7 +2718,6 @@ cuda_py_test( name = "summary_v1_audio_op_test", size = "small", srcs = ["summary_v1_audio_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -2843,7 +2731,6 @@ cuda_py_test( name = "summary_v1_image_op_test", size = "small", srcs = ["summary_v1_image_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -2894,7 +2781,6 @@ cuda_py_test( size = "small", srcs = ["trace_op_test.py"], tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:math_ops", @@ -2925,7 +2811,6 @@ cuda_py_test( name = "variable_ops_test", size = "small", srcs = ["variable_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2943,7 +2828,6 @@ cuda_py_test( name = "xent_op_test", size = "small", srcs = ["xent_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -2960,7 +2844,6 @@ cuda_py_test( name = "zero_division_test", size = "medium", srcs = ["zero_division_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -2976,7 +2859,6 @@ cuda_py_test( tags = [ "no_gpu", # Flaky: b/80127739, b/127001953 ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2993,7 +2875,6 @@ cuda_py_test( size = "medium", srcs = ["atrous_convolution_test.py"], tags = ["manual"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3008,7 +2889,6 @@ cuda_py_test( name = "pool_test", size = "medium", srcs = ["pool_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -3040,7 +2920,6 @@ cuda_py_test( name = "conv3d_transpose_test", size = "medium", srcs = ["conv3d_transpose_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -3084,7 +2963,6 @@ cuda_py_test( shard_count = 3, # TODO(b/118842098): Re-enable this test in Kokoro. tags = ["no_oss"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3101,7 +2979,6 @@ tf_py_test( size = "medium", srcs = ["neon_depthwise_conv_op_test.py"], tags = ["no_windows"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3118,7 +2995,6 @@ cuda_py_test( size = "medium", srcs = ["division_future_test.py"], tags = ["manual"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -3130,7 +3006,6 @@ cuda_py_test( name = "pooling_ops_3d_test", size = "medium", srcs = ["pooling_ops_3d_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -3165,7 +3040,6 @@ cuda_py_test( timeout = "long", srcs = ["rnn_test.py"], shard_count = 10, - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -3240,7 +3114,6 @@ cuda_py_test( tags = [ "no_oss", # Requires 4GB+ RAM ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3255,7 +3128,6 @@ cuda_py_test( size = "medium", srcs = ["sparse_matmul_op_test.py"], tags = ["no_windows"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -3291,7 +3163,6 @@ cuda_py_test( name = "sparse_tensor_dense_matmul_op_test", size = "medium", srcs = ["sparse_tensor_dense_matmul_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -3346,7 +3217,6 @@ cuda_py_test( name = "stage_op_test", size = "medium", srcs = ["stage_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3362,7 +3232,6 @@ cuda_py_test( size = "medium", srcs = ["map_stage_op_test.py"], tags = ["no_oss"], # b/124474135 - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3378,7 +3247,6 @@ cuda_py_test( size = "medium", srcs = ["concat_op_test.py"], tags = ["no_windows"], # b/126916429 - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", @@ -3400,7 +3268,6 @@ cuda_py_test( "nomsan", "notsan", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3414,7 +3281,6 @@ cuda_py_test( srcs = ["conv_ops_3d_test.py"], shard_count = 30, tags = ["no_cuda11"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -3493,7 +3359,6 @@ cuda_py_test( size = "medium", srcs = ["embedding_ops_test.py"], shard_count = 20, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3536,7 +3401,6 @@ cuda_py_test( size = "medium", srcs = ["matrix_band_part_op_test.py"], shard_count = 20, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3610,7 +3474,6 @@ cuda_py_test( "no_oss", # b/117185141. "nomsan", # TODO(b/117236102): Re-enable in msan build. ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3634,7 +3497,6 @@ cuda_py_test( "no_windows_gpu", "nomsan", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3654,7 +3516,6 @@ cuda_py_test( "no_windows_gpu", "nomsan", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -3683,7 +3544,6 @@ tf_py_test( name = "sets_test", size = "medium", srcs = ["sets_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:errors", "//tensorflow/python:framework", @@ -3702,7 +3562,6 @@ tf_py_test( size = "small", srcs = ["weights_broadcast_test.py"], shard_count = 3, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3722,7 +3581,6 @@ tf_py_test( srcs = ["metrics_test.py"], shard_count = 20, tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3743,7 +3601,6 @@ tf_py_test( name = "confusion_matrix_test", size = "small", srcs = ["confusion_matrix_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3759,7 +3616,6 @@ cuda_py_test( name = "bucketize_op_test", size = "medium", srcs = ["bucketize_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -3773,7 +3629,6 @@ tf_py_test( size = "small", srcs = ["sparse_cross_op_test.py"], tags = ["no_windows"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -3786,7 +3641,6 @@ tf_py_test( name = "garbage_collection_test", size = "small", srcs = ["garbage_collection_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -3943,7 +3797,6 @@ cuda_py_test( srcs = ["tridiagonal_matmul_op_test.py"], shard_count = 10, tags = ["no_rocm"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD index df17e5a3a39..60dcf7bc791 100644 --- a/tensorflow/python/kernel_tests/array_ops/BUILD +++ b/tensorflow/python/kernel_tests/array_ops/BUILD @@ -10,7 +10,6 @@ package( cuda_py_test( name = "batch_gather_op_test", srcs = ["batch_gather_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -24,7 +23,6 @@ cuda_py_test( name = "unstack_op_test", size = "small", srcs = ["unstack_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -39,7 +37,6 @@ cuda_py_test( name = "slice_op_test", size = "medium", srcs = ["slice_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 8eb5af399b4..4106ea9b166 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1628,6 +1628,22 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase): axis=(axis - 4))) self.assertAllClose(fake_quantized, expected) + def testBadAxis(self): + input_tensor = [2.5, 2.5] + input_min = [0, 0] + input_max = [1, 1] + error_message_pattern = "Shape must be at least rank 11 but is rank 1" + # TODO(b/171260356): Eager mode and graph mode throw different error types + error = errors.InvalidArgumentError if context.executing_eagerly( + ) else ValueError + with self.assertRaisesRegex(error, error_message_pattern): + self.evaluate( + array_ops.quantize_and_dequantize_v2( + input=input_tensor, + input_min=input_min, + input_max=input_max, + axis=10)) + def testQuantizeDequantizeGrad(self): shape = (2, 2) max_threshold = 0 diff --git a/tensorflow/python/kernel_tests/boosted_trees/BUILD b/tensorflow/python/kernel_tests/boosted_trees/BUILD index e7a34382355..5b318324d4c 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/BUILD +++ b/tensorflow/python/kernel_tests/boosted_trees/BUILD @@ -24,7 +24,6 @@ tf_py_test( name = "resource_ops_test", size = "small", srcs = ["resource_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py", "//tensorflow/python:boosted_trees_ops", @@ -40,7 +39,6 @@ tf_py_test( name = "prediction_ops_test", size = "small", srcs = ["prediction_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py", "//tensorflow/python:array_ops", @@ -55,7 +53,6 @@ tf_py_test( name = "stats_ops_test", size = "medium", srcs = ["stats_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:boosted_trees_ops", @@ -72,7 +69,6 @@ tf_py_test( name = "training_ops_test", size = "small", srcs = ["training_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py", "//tensorflow/python:array_ops", @@ -87,7 +83,6 @@ tf_py_test( name = "quantile_ops_test", size = "small", srcs = ["quantile_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py", "//tensorflow/python:boosted_trees_ops", diff --git a/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py b/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py index 95143cc233a..5c9a351e327 100644 --- a/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py +++ b/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py @@ -20,16 +20,22 @@ from __future__ import print_function import copy import os +import threading import time +from absl.testing import parameterized + from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib +from tensorflow.python.distribute import combinations from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.ops import collective_ops @@ -46,6 +52,18 @@ def enable_collective_ops(cluster_resolver): context.context().enable_collective_ops(server_def) +def enable_collective_ops_with_barrier(cluster_resolver): + multi_process_runner.get_barrier().wait() + enable_collective_ops(cluster_resolver) + multi_process_runner.get_barrier().wait() + + +device_combination = ( + combinations.combine(device="CPU", communication="RING", required_gpus=0) + + combinations.combine( + device="GPU", communication=["RING", "NCCL"], required_gpus=1)) + + class CollectiveOpTest(test.TestCase): def testCheckHealth(self): @@ -60,7 +78,8 @@ class CollectiveOpTest(test.TestCase): "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1", ]: - context.context().check_collective_ops_peer_health(task) + context.context().check_collective_ops_peer_health( + task, timeout_in_ms=1000) except errors.UnavailableError: continue break @@ -76,12 +95,13 @@ class CollectiveOpTest(test.TestCase): def worker_fn(): enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver()) context.context().check_collective_ops_peer_health( - "/job:worker/replica:0/task:1",) + "/job:worker/replica:0/task:1", timeout_in_ms=1000) cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec) mpr.start_single_process("worker", 0) - with self.assertRaises(errors.UnavailableError): + with self.assertRaises( + (errors.UnavailableError, errors.DeadlineExceededError)): mpr.join() def testCheckHealthPeerRestart(self): @@ -109,7 +129,7 @@ class CollectiveOpTest(test.TestCase): time.sleep(1) try: context.context().check_collective_ops_peer_health( - "/job:worker/replica:0/task:0",) + "/job:worker/replica:0/task:0", timeout_in_ms=1000) except errors.UnavailableError: pass except errors.FailedPreconditionError: @@ -126,7 +146,8 @@ class CollectiveOpTest(test.TestCase): def worker_fn(): enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver()) - context.context().check_collective_ops_peer_health("localhost:12345",) + context.context().check_collective_ops_peer_health( + "localhost:12345", timeout_in_ms=1000) cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec) @@ -135,5 +156,176 @@ class CollectiveOpTest(test.TestCase): mpr.join() +two_worker_pool_runner = multi_process_runner.MultiProcessPoolRunner( + multi_worker_test_base.create_cluster_spec(num_workers=2), + initializer=lambda: enable_collective_ops(cluster_resolver_lib. + TFConfigClusterResolver())) + + +@combinations.generate( + combinations.times( + combinations.combine( + mode="eager", num_workers=2, runner=two_worker_pool_runner), + device_combination)) +class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): + + def testAbortCommunication(self, device, communication): + if communication == "NCCL": + self.skipTest("b/171358086: cannot test multi worker NCCL") + dev0 = "/device:%s:0" % device + cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver() + enable_collective_ops_with_barrier(cluster_resolver) + group_size = 2 + group_key = 100 + instance_key = 100 + in_tensor = constant_op.constant([1.]) + + # First perform a normal all-reduce to complete the group and instance + # resolution. + with ops.device(dev0): + collective_ops.all_reduce( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) + + if cluster_resolver.task_id == 1: + + def abort_fn(): + time.sleep(2) + context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down") + + t = threading.Thread(target=abort_fn) + t.start() + + with self.assertRaisesRegex(errors.UnavailableError, "peer down"): + with ops.device(dev0): + collective_ops.all_reduce( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) + + # After abortion, subsequent collectives should fail immediately. + with self.assertRaisesRegex(errors.UnavailableError, "peer down"): + with ops.device(dev0): + collective_ops.all_reduce( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) + + t.join() + + # Enable collective ops again in order to reset the collective executor. + enable_collective_ops_with_barrier(cluster_resolver) + with ops.device(dev0): + collective_ops.all_reduce( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) + + def testAbortGroupParamsResolution(self, device, communication): + if communication == "NCCL": + self.skipTest("b/171358086: cannot test multi worker NCCL") + dev0 = "/device:%s:0" % device + cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver() + enable_collective_ops_with_barrier(cluster_resolver) + group_size = 2 + group_key = 100 + instance_key = 100 + in_tensor = constant_op.constant([1.]) + + if cluster_resolver.task_id == 1: + + def abort_fn(): + time.sleep(2) + context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down") + + t = threading.Thread(target=abort_fn) + t.start() + + with self.assertRaisesRegex(errors.UnavailableError, "peer down"): + # This hangs on params resolution since we're only launching one + # collective for a group size of 2. + with ops.device(dev0): + collective_ops.all_reduce(in_tensor, group_size, group_key, + instance_key) + + # After abortion, subsequent collectives should fail immediately. + with self.assertRaisesRegex(errors.UnavailableError, "peer down"): + with ops.device(dev0): + collective_ops.all_reduce(in_tensor, group_size, group_key, + instance_key) + + t.join() + + # Enable collective ops again in order to reset the collective executor. + enable_collective_ops_with_barrier(cluster_resolver) + with ops.device(dev0): + collective_ops.all_reduce(in_tensor, group_size, group_key, instance_key) + + def testAbortInstanceParamsResolution(self, device, communication): + if communication == "NCCL": + self.skipTest("b/171358086: cannot test multi worker NCCL") + dev0 = "/device:%s:0" % device + cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver() + enable_collective_ops_with_barrier(cluster_resolver) + group_size = 2 + group_key = 100 + instance_key = 100 + in_tensor = constant_op.constant([1.]) + + # First perform a normal all-reduce to complete the group resolution. + with ops.device(dev0): + collective_ops.all_reduce(in_tensor, group_size, group_key, instance_key) + + # We use broadcast to test aborting instance resolution since only broadcast + # waits for the group. + + if cluster_resolver.task_id == 1: + + def abort_fn(): + time.sleep(2) + context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down") + + t = threading.Thread(target=abort_fn) + t.start() + + # Use a different instance key to trigger another instance resolution. + instance_key = 101 + with self.assertRaisesRegex(errors.UnavailableError, "peer down"): + # This hangs on params resolution since we're only launching one + # collective for a group size of 2. + with ops.device(dev0): + collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32, + group_size, group_key, instance_key) + + # After abortion, subsequent collectives should fail immediately. + with self.assertRaisesRegex(errors.UnavailableError, "peer down"): + with ops.device(dev0): + collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32, + group_size, group_key, instance_key) + + t.join() + + # Enable collective ops again in order to reset the collective executor. + enable_collective_ops_with_barrier(cluster_resolver) + # Reassign instance_key so that it's the same on each worker. + instance_key = 100 + with ops.device(dev0): + if cluster_resolver.task_id == 0: + collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32, + group_size, group_key, instance_key) + else: + collective_ops.broadcast_recv((1,), dtypes.float32, group_size, + group_key, instance_key) + + if __name__ == "__main__": multi_process_runner.test_main() diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py index 0e3e16179a6..fe558bcae64 100644 --- a/tensorflow/python/kernel_tests/collective_ops_test.py +++ b/tensorflow/python/kernel_tests/collective_ops_test.py @@ -24,13 +24,16 @@ import time from absl.testing import parameterized from tensorflow.python.compat import v2_compat +from tensorflow.python.data.experimental.ops import testing as dataset_testing +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import test_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function -from tensorflow.python.framework import combinations -from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops as _collective_ops from tensorflow.python.platform import test @@ -41,161 +44,239 @@ class CollectiveOpsV1(object): class CollectiveOpsV2(object): - all_reduce = _collective_ops.all_reduce_v2 - all_gather = _collective_ops.all_gather_v2 + + @staticmethod + def all_reduce(t, group_size, group_key, instance_key, *args, **kwargs): + group_size = array_ops.identity(group_size) + group_key = array_ops.identity(group_key) + instance_key = array_ops.identity(instance_key) + return _collective_ops.all_reduce_v2(t, group_size, group_key, instance_key, + *args, **kwargs) + + @staticmethod + def all_gather(t, group_size, group_key, instance_key, *args, **kwargs): + group_size = array_ops.identity(group_size) + group_key = array_ops.identity(group_key) + instance_key = array_ops.identity(instance_key) + return _collective_ops.all_gather_v2(t, group_size, group_key, instance_key, + *args, **kwargs) + + +device_combination = ( + combinations.combine(device='CPU', communication='RING', required_gpus=0) + + combinations.combine( + device='GPU', communication=['RING', 'NCCL'], required_gpus=2)) @combinations.generate( - combinations.combine( - collective_ops=[ - combinations.NamedObject('v1', CollectiveOpsV1), - combinations.NamedObject('v2', CollectiveOpsV2) - ], - mode='eager')) + combinations.times( + combinations.combine( + collective_ops=[ + combinations.NamedObject('v1', CollectiveOpsV1), + combinations.NamedObject('v2', CollectiveOpsV2) + ], + mode='eager'), device_combination)) class CollectiveOpsTest(test.TestCase, parameterized.TestCase): def setUp(self): _setup_context() super().setUp() - def testReduce(self, collective_ops): + def testReduce(self, collective_ops, device, communication): + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device @def_function.function - def run_all_reduce_1cpu(): - with ops.device('/device:CPU:0'): + def run_all_reduce_1device(): + with ops.device(dev0): in_value = constant_op.constant([1.]) group_size = 1 group_key = 1 instance_key = 1 - return collective_ops.all_reduce(in_value, group_size, group_key, - instance_key) + return collective_ops.all_reduce( + in_value, + group_size, + group_key, + instance_key, + communication_hint=communication) @def_function.function - def run_all_reduce_2cpus(): + def run_all_reduce_2devices(): in_value = constant_op.constant([1.]) group_size = 2 group_key = 2 instance_key = 2 collectives = [] - with ops.device('/device:CPU:0'): + with ops.device(dev0): collectives.append( - collective_ops.all_reduce(in_value, group_size, group_key, - instance_key)) - with ops.device('/device:CPU:1'): + collective_ops.all_reduce( + in_value, + group_size, + group_key, + instance_key, + communication_hint=communication)) + with ops.device(dev1): collectives.append( - collective_ops.all_reduce(in_value, group_size, group_key, - instance_key)) + collective_ops.all_reduce( + in_value, + group_size, + group_key, + instance_key, + communication_hint=communication)) return collectives - self.assertAllClose(run_all_reduce_1cpu(), [1.], rtol=1e-5, atol=1e-5) - for result in run_all_reduce_2cpus(): + self.assertAllClose(run_all_reduce_1device(), [1.], rtol=1e-5, atol=1e-5) + for result in run_all_reduce_2devices(): self.assertAllClose(result, [2.], rtol=1e-5, atol=1e-5) - def testGather(self, collective_ops): + def testGather(self, collective_ops, device, communication): + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device @def_function.function - def run_all_gather_1cpu(): - with ops.device('/device:CPU:0'): + def run_all_gather_1device(): + with ops.device(dev0): in_value = constant_op.constant([1.]) group_size = 1 group_key = 1 instance_key = 1 - return collective_ops.all_gather(in_value, group_size, group_key, - instance_key) + return collective_ops.all_gather( + in_value, + group_size, + group_key, + instance_key, + communication_hint=communication) @def_function.function - def run_all_gather_2cpus(): + def run_all_gather_2devices(): in_value = constant_op.constant([1.]) group_size = 2 group_key = 2 instance_key = 2 collectives = [] - with ops.device('/device:CPU:0'): + with ops.device(dev0): collectives.append( - collective_ops.all_gather(in_value, group_size, group_key, - instance_key)) - with ops.device('/device:CPU:1'): + collective_ops.all_gather( + in_value, + group_size, + group_key, + instance_key, + communication_hint=communication)) + with ops.device(dev1): collectives.append( - collective_ops.all_gather(in_value, group_size, group_key, - instance_key)) + collective_ops.all_gather( + in_value, + group_size, + group_key, + instance_key, + communication_hint=communication)) return collectives - self.assertAllClose(run_all_gather_1cpu(), [1.], rtol=1e-5, atol=1e-5) - for result in run_all_gather_2cpus(): + self.assertAllClose(run_all_gather_1device(), [1.], rtol=1e-5, atol=1e-5) + for result in run_all_gather_2devices(): self.assertAllClose(result, [1., 1.], rtol=1e-5, atol=1e-5) - def testInstanceKeyScopedUnderGroupKey(self, collective_ops): + def testInstanceKeyScopedUnderGroupKey(self, collective_ops, device, + communication): + if device == 'GPU' and context.num_gpus() < 4: + self.skipTest('not enough GPU') + + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device + dev2 = '/device:%s:2' % device + dev3 = '/device:%s:3' % device @def_function.function - def run_all_reduce_4cpus_same_instance_key(): + def run_all_reduce_4devices_same_instance_key(): # Use a common instance key for both groups. instance_key = 0 # We will create 2 groups each with 2 devices. group_size = 2 - # Group 0 comprises cpu:0 and cpu:1. + # Group 0 comprises dev0 and dev1. group0_key = 0 - # Group 1 comprises cpu:2 and cpu:3. + # Group 1 comprises dev2 and dev3. group1_key = 1 collectives = [] - with ops.device('/device:CPU:0'): + with ops.device(dev0): collectives.append( collective_ops.all_reduce( constant_op.constant(1.), group_size, group0_key, instance_key)) - with ops.device('/device:CPU:1'): + with ops.device(dev1): collectives.append( collective_ops.all_reduce( constant_op.constant(2.), group_size, group0_key, instance_key)) - with ops.device('/device:CPU:2'): + with ops.device(dev2): collectives.append( collective_ops.all_reduce( constant_op.constant(3.), group_size, group1_key, instance_key)) - with ops.device('/device:CPU:3'): + with ops.device(dev3): collectives.append( collective_ops.all_reduce( constant_op.constant(4.), group_size, group1_key, instance_key)) return collectives - results = run_all_reduce_4cpus_same_instance_key() + results = run_all_reduce_4devices_same_instance_key() self.assertAllClose(results[0], 3., rtol=1e-5, atol=1e-5) self.assertAllClose(results[1], 3., rtol=1e-5, atol=1e-5) self.assertAllClose(results[2], 7., rtol=1e-5, atol=1e-5) self.assertAllClose(results[3], 7., rtol=1e-5, atol=1e-5) - def testCollectiveGroupSizeOne(self, collective_ops): + def testCollectiveGroupSizeOne(self, collective_ops, device, communication): + if communication == 'NCCL': + self.skipTest('b/170672646: it crashes with NCCL and group size one') + dev0 = '/device:%s:0' % device + group_size = 1 group_key = 100 instance_key = 100 - in_value = [1, 2, 3, 4] + in_value = [1., 2., 3., 4.] in_tensor = constant_op.constant(in_value) - reduced_tensor = collective_ops.all_reduce(in_tensor, group_size, group_key, - instance_key) + with ops.device(dev0): + reduced_tensor = collective_ops.all_reduce( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) self.assertAllEqual(in_value, reduced_tensor.numpy()) - gathered_tensor = collective_ops.all_gather( - in_tensor, group_size, group_key, instance_key) + with ops.device(dev0): + gathered_tensor = collective_ops.all_gather( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) self.assertAllEqual(in_value, gathered_tensor.numpy()) - def testMultipleGroups(self, collective_ops): + def testMultipleGroups(self, collective_ops, device, communication): + if device == 'GPU' and context.num_gpus() < 4: + self.skipTest('not enough GPU') + num_elements = 4 @def_function.function def run_all_reduce(group_size, group_key): instance_key = group_key - input_value = [group_key for i in range(num_elements)] + input_value = [float(group_key) for i in range(num_elements)] collectives = [] for device_idx in range(group_size): - with ops.device('/CPU:{}'.format(device_idx)): + with ops.device('/{}:{}'.format(device, device_idx)): input_tensor = constant_op.constant(input_value) collectives.append( - collective_ops.all_reduce(input_tensor, group_size, group_key, - instance_key)) + collective_ops.all_reduce( + input_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication)) return collectives def run_and_assert(group_size, group_key): for reduced_tensor in run_all_reduce(group_size, group_key): self.assertAllEqual( - [group_key * group_size for i in range(num_elements)], + [float(group_key) * group_size for i in range(num_elements)], reduced_tensor.numpy()) run_and_assert(group_size=2, group_key=1) @@ -203,24 +284,29 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): @combinations.generate( - combinations.combine( - collective_op=[ - combinations.NamedObject('all_reduce', _collective_ops.all_reduce), - combinations.NamedObject('all_reduce_v2', - _collective_ops.all_reduce_v2), - combinations.NamedObject('all_gather', _collective_ops.all_gather), - combinations.NamedObject('all_gather_v2', - _collective_ops.all_gather_v2), - ], - mode='eager', - communication=['ring'])) + combinations.times( + combinations.combine( + collective_op=[ + combinations.NamedObject('all_reduce', + CollectiveOpsV1.all_reduce), + combinations.NamedObject('all_reduce_v2', + CollectiveOpsV2.all_reduce), + combinations.NamedObject('all_gather', + CollectiveOpsV1.all_gather), + combinations.NamedObject('all_gather_v2', + CollectiveOpsV2.all_gather), + ], + mode='eager'), device_combination)) class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): def setUp(self): _setup_context() super().setUp() - def testAbortGroupParamsResolution(self, collective_op, communication): + def testAbortGroupParamsResolution(self, collective_op, device, + communication): + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device group_size = 2 group_key = 100 instance_key = 100 @@ -236,11 +322,23 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(errors.UnavailableError, 'peer down'): # This hangs on params resolution since we're only launching one # collective for a group size of 2. - collective_op(in_tensor, group_size, group_key, instance_key) + with ops.device(dev0): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) # After abortion, subsequent collectives should fail immediately. with self.assertRaisesRegex(errors.UnavailableError, 'peer down'): - collective_op(in_tensor, group_size, group_key, instance_key) + with ops.device(dev0): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) t.join() # Reset the context in order to reset the collective executor. @@ -248,7 +346,7 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): # After reset non-NCCL collectives should work. def collective_fn(): - for device in ['CPU:0', 'CPU:1']: + for device in [dev0, dev1]: with ops.device(device): collective_op( in_tensor, @@ -259,14 +357,17 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): def_function.function(collective_fn)() - def testAbortInstanceParamsResolution(self, collective_op, communication): + def testAbortInstanceParamsResolution(self, collective_op, device, + communication): + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device group_size = 2 group_key = 100 instance_key = 100 in_tensor = constant_op.constant([1.]) def collective_fn(): - for device in ['CPU:0', 'CPU:1']: + for device in [dev0, dev1]: with ops.device(device): collective_op( in_tensor, @@ -290,11 +391,23 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(errors.UnavailableError, 'peer down'): # This hangs on params resolution since we're only launching one # collective for a group size of 2. - collective_op(in_tensor, group_size, group_key, instance_key) + with ops.device(dev0): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) # After abortion, subsequent collectives should fail immediately. with self.assertRaisesRegex(errors.UnavailableError, 'peer down'): - collective_op(in_tensor, group_size, group_key, instance_key) + with ops.device(dev0): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) context._reset_context() # pylint: disable=protected-access t.join() @@ -304,7 +417,9 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): # After reset non-NCCL collectives should work. def_function.function(collective_fn)() - def testAbortCommunication(self, collective_op, communication): + def testAbortCommunication(self, collective_op, device, communication): + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device group_size = 2 group_key = 100 instance_key = 100 @@ -312,7 +427,7 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): # First perform a normal collective to finish resolution. def collective_fn(): - for device in ['CPU:0', 'CPU:1']: + for device in [dev0, dev1]: with ops.device(device): collective_op( in_tensor, @@ -333,11 +448,23 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): t.start() with self.assertRaisesRegex(errors.UnavailableError, 'peer down'): - collective_op(in_tensor, group_size, group_key, instance_key) + with ops.device(dev0): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) # After abortion, subsequent collectives should fail immediately. with self.assertRaisesRegex(errors.UnavailableError, 'peer down'): - collective_op(in_tensor, group_size, group_key, instance_key) + with ops.device(dev0): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) # Reset the context in order to reset the collective executor. t.join() @@ -345,37 +472,214 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): def_function.function(collective_fn)() +class OpCancellationTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + _setup_context() + super().setUp() + + @combinations.generate( + combinations.times( + combinations.combine( + collective_op=[ + combinations.NamedObject('all_reduce', + CollectiveOpsV1.all_reduce), + combinations.NamedObject('all_reduce_v2', + CollectiveOpsV2.all_reduce), + combinations.NamedObject('all_gather', + CollectiveOpsV1.all_gather), + combinations.NamedObject('all_gather_v2', + CollectiveOpsV2.all_gather), + ], + mode='eager'), device_combination)) + def testOpErrorNotAbortIfNoCollective(self, collective_op, device, + communication): + # Do not abort if there's no active collective ops. There could be + # exceptions like EOF which we expect users to catch, aborting collective + # ops on all op errors intervenes with this workflow. + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device + group_size = 2 + group_key = 100 + instance_key = 100 + dataset = dataset_ops.Dataset.from_tensors([1.]) + + @def_function.function + def collective_fn(in_tensor): + for device in [dev0, dev1]: + with ops.device(device): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) + + @def_function.function + def f(): + iterator = iter(dataset) + collective_fn(next(iterator)) + # This next(iterator) should raise EOF. + collective_fn(next(iterator)) + + with self.assertRaises(errors.OutOfRangeError): + f() + collective_fn(constant_op.constant([1.])) + + @combinations.generate( + combinations.times( + combinations.combine( + collective_op=[ + combinations.NamedObject('all_reduce', + CollectiveOpsV1.all_reduce), + combinations.NamedObject('all_gather', + CollectiveOpsV1.all_gather), + ], + mode='eager'), device_combination)) + def testOpErrorAbortWithCollective(self, collective_op, device, + communication): + # Abort v1 collective ops if there're active collective ops at the time of + # an op error. This is due to the inability to cancel collective ops, and op + # errors may cause running collective ops to hang. + dev0 = '/device:%s:0' % device + group_size = 2 + group_key = 100 + instance_key = 100 + in_tensor = constant_op.constant([1.]) + # Make the dataset sleep a while so that the collective is being executed + # when the EOF happens. + dataset = dataset_ops.Dataset.from_tensors([1.]).apply( + dataset_testing.sleep(sleep_microseconds=200)) + + @def_function.function + def f(): + # Launch a collective op that won't be able to finish to test abortion + # when other ops error. + with ops.device(dev0): + ret = collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) + iterator = iter(dataset) + next(iterator) + # This should raise EOF. + next(iterator) + return ret + + with self.assertRaises(errors.OutOfRangeError): + f() + # Now collective ops is aborted, subsequent collective ops should fail with + # the previous error. + with self.assertRaises(errors.CancelledError): + with ops.device(dev0): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) + + @combinations.generate( + combinations.times( + combinations.combine( + collective_op=[ + combinations.NamedObject('all_reduce_v2', + CollectiveOpsV2.all_reduce), + combinations.NamedObject('all_gather_v2', + CollectiveOpsV2.all_gather), + ], + mode='eager'), device_combination)) + def testOpErrorNotAbortWithCollective(self, collective_op, device, + communication): + # Do not abort v2 collective ops even if there're active collective ops at + # the time of an op error. We rely cancellation to terminate active + # collective ops. + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device + group_size = 2 + group_key = 100 + instance_key = 100 + in_tensor = constant_op.constant([1.]) + + @def_function.function + def collective_fn(): + for device in [dev0, dev1]: + with ops.device(device): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) + + # Local params resolution cannot be cancelled yet, so we perform a normal + # collective so that the group is resolved. + collective_fn() + + # Make the dataset sleep a while so that the collective is being executed + # when the EOF happens. + dataset = dataset_ops.Dataset.from_tensors([1.]).apply( + dataset_testing.sleep(sleep_microseconds=200)) + + @def_function.function + def f(): + # Launch a collective op that won't be able to finish to test cancellation + # when other ops error. + with ops.device(dev0): + ret = collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) + iterator = iter(dataset) + next(iterator) + # This should raise EOF. + next(iterator) + return ret + + with self.assertRaises(errors.OutOfRangeError): + f() + # Collective ops shouldn't be aborted and new collectives should be able to + # proceed. + collective_fn() + + @combinations.generate( - combinations.combine( - collective_op=[ - combinations.NamedObject('all_reduce', _collective_ops.all_reduce), - combinations.NamedObject('all_reduce_v2', - _collective_ops.all_reduce_v2), - combinations.NamedObject('all_gather', _collective_ops.all_gather), - combinations.NamedObject('all_gather_v2', - _collective_ops.all_gather_v2), - ], - mode='eager', - communication=['ring'])) + combinations.times( + combinations.combine( + collective_op=[ + combinations.NamedObject('all_reduce', + CollectiveOpsV1.all_reduce), + combinations.NamedObject('all_reduce_v2', + CollectiveOpsV2.all_reduce), + combinations.NamedObject('all_gather', + CollectiveOpsV1.all_gather), + combinations.NamedObject('all_gather_v2', + CollectiveOpsV2.all_gather), + ], + mode='eager'), device_combination)) class TimeoutTest(test.TestCase, parameterized.TestCase): def setUp(self): _setup_context() super().setUp() - def testTimeout(self, collective_op, communication): - timeout = 4.5 + def testTimeout(self, collective_op, device, communication): + timeout = 1.5 @def_function.function def run(group_size, reported_group_size=None): group_key = 20 instance_key = 30 - tensor = [1, 2, 3, 4] + tensor = [1., 2., 3., 4.] results = [] if reported_group_size is None: reported_group_size = group_size for i in range(group_size): - with ops.device('/CPU:{}'.format(i)): + with ops.device('/{}:{}'.format(device, i)): input_data = constant_op.constant(tensor) result = collective_op( input_data, @@ -396,18 +700,20 @@ class TimeoutTest(test.TestCase, parameterized.TestCase): elapsed = time.time() - start_time self.assertAllGreaterEqual(elapsed, timeout) - def testParamResolutionAfterTimeoutV2(self, collective_op, communication): + def testParamResolutionAfterTimeout(self, collective_op, device, + communication): + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device timeout = 1.5 - group_key = 20 instance_key = 30 - input_data = constant_op.constant([1, 2, 3, 4]) + input_data = constant_op.constant([1., 2., 3., 4.]) # This timeout comes from param solution. with self.assertRaisesRegex( errors.DeadlineExceededError, 'Collective has timed out waiting for other workers'): - with ops.device('CPU:0'): + with ops.device(dev0): collective_op( input_data, group_size=2, @@ -418,28 +724,29 @@ class TimeoutTest(test.TestCase, parameterized.TestCase): # We launch the second device after the first device times out. This is to # simulate the situation when other workers are slow and the timeout is - # short. Since the CPU:0 times out in the param resolution phase, CPU:1 - # should times out as well, but in the execute phase. - with self.assertRaisesRegex(errors.DeadlineExceededError, - 'Collective has timed out during execution'): - with ops.device('CPU:1'): + # short. It should error immediately. + with self.assertRaisesRegex( + errors.DeadlineExceededError, + 'Collective has timed out waiting for other workers'): + with ops.device(dev1): collective_op( input_data, group_size=2, group_key=group_key, instance_key=instance_key, - communication_hint=communication, - timeout=timeout) + communication_hint=communication) - def testExecutionAfterTimeoutV2(self, collective_op, communication): + def testExecutionAfterTimeout(self, collective_op, device, communication): + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device timeout = 1.5 group_key = 20 instance_key = 30 - input_data = constant_op.constant([1, 2, 3, 4]) + input_data = constant_op.constant([1., 2., 3., 4.]) @def_function.function def run(): - for device in ['CPU:0', 'CPU:1']: + for device in [dev0, dev1]: with ops.device(device): collective_op( input_data, @@ -454,7 +761,7 @@ class TimeoutTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(errors.DeadlineExceededError, 'Collective has timed out during execution'): - with ops.device('CPU:0'): + with ops.device(dev0): collective_op( input_data, group_size=2, @@ -468,7 +775,7 @@ class TimeoutTest(test.TestCase, parameterized.TestCase): # short. It should error immediately. with self.assertRaisesRegex(errors.DeadlineExceededError, 'Collective has timed out during execution'): - with ops.device('CPU:1'): + with ops.device(dev1): # No timeout. collective_op( input_data, @@ -480,13 +787,7 @@ class TimeoutTest(test.TestCase, parameterized.TestCase): def _setup_context(): context._reset_context() - cpus = config.list_physical_devices('CPU') - config.set_logical_device_configuration(cpus[0], [ - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration() - ]) + test_util.set_logical_devices_to_at_least('CPU', 4) context.ensure_initialized() diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 70d7b2530a9..2011b3b4b45 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -960,6 +960,42 @@ class CondV2Test(test.TestCase): self.assertAllEqual(fn_with_cond(), 12.0) + def _CheckIteratedCosGradients(self, func): + + def _grad(f): + def _grad_function(primal): + with backprop.GradientTape() as tape: + tape.watch(primal) + primal_out = f(primal) + return tape.gradient(primal_out, primal) + return _grad_function + + f = func + one = constant_op.constant(1.) + for expected in [math_ops.cos, + lambda x: -math_ops.sin(x), + lambda x: -math_ops.cos(x), + math_ops.sin, + math_ops.cos]: + self.assertAllClose(expected(one), def_function.function(f)(one)) + f = _grad(f) + + def testIteratedGradientsCond(self): + def _func(x): + return cond_v2.cond_v2( + constant_op.constant(True), + lambda: math_ops.cos(array_ops.identity(x)), + lambda: math_ops.sin(array_ops.identity(x))) + self._CheckIteratedCosGradients(_func) + + def testIteratedGradientsCase(self): + def _func(x): + return cond_v2.indexed_case( + constant_op.constant(1), + [lambda: math_ops.sin(array_ops.identity(x)), + lambda: math_ops.cos(array_ops.identity(x))]) + self._CheckIteratedCosGradients(_func) + def testLowering(self): with ops.Graph().as_default() as g: with self.session(graph=g) as sess: @@ -1405,6 +1441,15 @@ class CondV2ContainerTest(test.TestCase): class CondV2ColocationGroupAndDeviceTest(test.TestCase): + def setUp(self): + super(CondV2ColocationGroupAndDeviceTest, self).setUp() + cpus = context.context().list_physical_devices("CPU") + context.context().set_logical_device_configuration( + cpus[0], [ + context.LogicalDeviceConfiguration(), + context.LogicalDeviceConfiguration() + ]) + def testColocateWithBeforeCond(self): with ops.Graph().as_default() as g: with self.session(graph=g): @@ -1480,31 +1525,64 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): self.assertTrue(len(run_metadata.partition_graphs) >= 2) def testDeviceBeforeCond(self): - with ops.Graph().as_default() as g: - with self.session(graph=g): - - def fn(): - self.assertEqual("", constant_op.constant(3.0).op.device) - return test_ops.device_placement_op() + with context.eager_mode(): + def fn(): + cpu_zero_op = test_ops.device_placement_op() + self.assertEqual("/device:CPU:0", cpu_zero_op.device) + with ops.device("CPU:1"): + cpu_one_op = test_ops.device_placement_op() + self.assertEqual("/device:CPU:1", cpu_one_op.device) + return cpu_zero_op, cpu_one_op + @def_function.function + def _cond_wrapper(): with ops.device("/device:CPU:0"): - self.assertIn( - compat.as_bytes("CPU:0"), - self.evaluate(cond_v2.cond_v2(constant_op.constant(True), - fn, fn))) + return cond_v2.cond_v2(constant_op.constant(True), fn, fn) - def fn2(): - self.assertEqual("", constant_op.constant(3.0).op.device) - return test_ops.device_placement_op() + zero_expected, one_expected = self.evaluate(_cond_wrapper()) + self.assertIn(compat.as_bytes("CPU:0"), zero_expected) + self.assertIn(compat.as_bytes("CPU:1"), one_expected) - if test_util.is_gpu_available(): - with ops.device("/device:GPU:0"): - self.assertIn( - compat.as_bytes("GPU:0"), - self.evaluate(cond_v2.cond_v2(constant_op.constant(True), - fn2, fn2))) - else: - self.skipTest("Test requires a GPU to check GPU device placement.") + def fn2(): + self.assertEqual("/device:GPU:0", constant_op.constant(3.0).op.device) + return test_ops.device_placement_op() + + @def_function.function + def _cond_wrapper2(): + with ops.device("/device:GPU:0"): + return cond_v2.cond_v2(constant_op.constant(True), fn2, fn2) + + if test_util.is_gpu_available(): + self.assertIn(compat.as_bytes("GPU:0"), + self.evaluate(_cond_wrapper2())) + else: + self.skipTest("Test requires a GPU to check GPU device placement.") + + def testColocationBeforeCond(self): + with context.eager_mode(): + + def _fn(): + result = test_ops.device_placement_op() + self.assertIn("colocation_test_op", + result.op.colocation_groups()[0].decode()) + return result + + @def_function.function(autograph=False) + def _cond_wrapper(): + with ops.device("/device:CPU:0"): + op_on_cpu_0 = test_ops.device_placement_op(name="colocation_test_op") + with ops.device("/device:CPU:1"): + op_on_cpu_1 = test_ops.device_placement_op( + name="colocation_test_op_1") + condition = constant_op.constant(True) + with ops.colocate_with(op_on_cpu_0.op): + zero_expected = cond_v2.cond_v2(condition, _fn, _fn) + with ops.colocate_with(op_on_cpu_1.op): + one_expected = cond_v2.cond_v2(condition, _fn, _fn) + return zero_expected, one_expected + zero_expected, one_expected = self.evaluate(_cond_wrapper()) + self.assertIn(compat.as_bytes("CPU:0"), zero_expected) + self.assertIn(compat.as_bytes("CPU:1"), one_expected) def testDeviceInAndOutOfCond(self): with ops.Graph().as_default() as g: diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 54bbd2b2e9e..532dac1d85a 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -730,6 +730,8 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): g for g in run_metadata.partition_graphs if device_str in g.node[0].device ] + if not device_graphs: + return 0 self.assertLen(device_graphs, 1) switch_nodes = [ n for n in device_graphs[0].node @@ -759,7 +761,6 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): options = config_pb2.RunOptions(output_partition_graphs=True) sess.run( r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) - self.assertLen(run_metadata.partition_graphs, 2) # Check that the Switch for `arg` gets placed on CPU. self.assertEqual( self._count_matching_switch_nodes_on_device(run_metadata, "CPU", diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 96628a1a06a..d98942af52f 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -767,6 +767,14 @@ class MinMaxOpTest(test.TestCase): self._compare(x.astype(t), y.astype(t), use_gpu=False) self._compare(x.astype(t), y.astype(t), use_gpu=True) + def testNaNPropagation(self): + x = np.array([1., np.nan, 1., np.nan], dtype=np.float64) + y = np.array([1., 1., np.nan, np.nan], dtype=np.float64) + for t in [np.float16, np.float32, np.float64]: + with self.subTest(t=t): + self._compare(x.astype(t), y.astype(t), use_gpu=False) + self._compare(x.astype(t), y.astype(t), use_gpu=True) + def testDifferentShapes(self): x = np.random.rand(1, 3, 2) * 100. y = np.random.rand(2) * 100. # should broadcast diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD index 4b0915b37e7..fc1c9af6946 100644 --- a/tensorflow/python/kernel_tests/distributions/BUILD +++ b/tensorflow/python/kernel_tests/distributions/BUILD @@ -11,7 +11,6 @@ cuda_py_test( name = "bijector_test", size = "small", srcs = ["bijector_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -49,7 +48,6 @@ cuda_py_test( name = "kullback_leibler_test", size = "small", srcs = ["kullback_leibler_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -62,7 +60,6 @@ cuda_py_test( name = "beta_test", size = "small", srcs = ["beta_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -268,7 +265,6 @@ cuda_py_test( name = "special_math_test", size = "medium", srcs = ["special_math_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -285,7 +281,6 @@ cuda_py_test( name = "identity_bijector_test", size = "small", srcs = ["identity_bijector_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py index a0d8bef327d..ba6df8ffb4f 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py @@ -197,7 +197,7 @@ class DirichletTest(test.TestCase): self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.) self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.) - self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.) + self.assertAllClose(sample_var_, analytic_var, atol=0.04, rtol=0.) self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) @test_util.run_without_tensor_float_32( diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 67d1adba3a9..097183d1025 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -27,7 +27,6 @@ cuda_py_test( name = "linear_operator_addition_test", size = "small", srcs = ["linear_operator_addition_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -66,7 +65,6 @@ cuda_py_test( name = "linear_operator_algebra_test", size = "small", srcs = ["linear_operator_algebra_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/kernel_tests/linalg/sparse/BUILD b/tensorflow/python/kernel_tests/linalg/sparse/BUILD index 560ba7b2fd4..96ebc38ce5a 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/BUILD +++ b/tensorflow/python/kernel_tests/linalg/sparse/BUILD @@ -11,7 +11,6 @@ cuda_py_test( name = "conjugate_gradient_test", size = "medium", srcs = ["conjugate_gradient_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -29,7 +28,6 @@ cuda_py_test( size = "medium", srcs = ["csr_sparse_matrix_test.py"], main = "csr_sparse_matrix_test.py", - tfrt_enabled = True, deps = [ "//tensorflow/python/ops/linalg/sparse", ], @@ -55,7 +53,6 @@ cuda_py_test( srcs = ["csr_sparse_matrix_grad_test.py"], main = "csr_sparse_matrix_grad_test.py", shard_count = 50, - tfrt_enabled = True, deps = [ "//tensorflow/python/ops/linalg/sparse", ], @@ -67,7 +64,6 @@ cuda_py_test( srcs = ["csr_sparse_matrix_dense_mat_mul_grad_test.py"], main = "csr_sparse_matrix_dense_mat_mul_grad_test.py", shard_count = 50, - tfrt_enabled = True, deps = [ "//tensorflow/python/ops/linalg/sparse", ], @@ -79,7 +75,6 @@ cuda_py_test( srcs = ["csr_sparse_matrix_sparse_mat_mul_grad_test.py"], main = "csr_sparse_matrix_sparse_mat_mul_grad_test.py", shard_count = 50, - tfrt_enabled = True, deps = [ "//tensorflow/python/ops/linalg/sparse", ], diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py index f42600bd334..b9d9f007167 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py @@ -587,7 +587,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t)) self.assertAllClose( - c_t_value, c_dense_t_value, rtol=1e-6, atol=1e-5) + c_t_value, c_dense_t_value, rtol=1e-6, atol=2e-5) @test_util.run_in_graph_and_eager_modes def testLargeBatchSparseMatrixMatMulTransposed(self): @@ -650,7 +650,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): self.assertAllEqual(c_t.shape, c_dense_t.shape) c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t)) self.assertAllClose( - c_t_value, c_dense_t_value, rtol=1e-6, atol=1e-5) + c_t_value, c_dense_t_value, rtol=1e-6, atol=2e-5) @test_util.run_in_graph_and_eager_modes def testLargeBatchSparseMatrixMatMulConjugate(self): diff --git a/tensorflow/python/kernel_tests/proto/BUILD b/tensorflow/python/kernel_tests/proto/BUILD index e5c46c76e2e..0e935dfe8c4 100644 --- a/tensorflow/python/kernel_tests/proto/BUILD +++ b/tensorflow/python/kernel_tests/proto/BUILD @@ -126,7 +126,6 @@ tf_py_test( tags = [ "no_pip", ], - tfrt_enabled = True, deps = [ ":descriptor_source_test_base", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py index 0a618b7f555..a9d855a5a2b 100644 --- a/tensorflow/python/kernel_tests/qr_op_test.py +++ b/tensorflow/python/kernel_tests/qr_op_test.py @@ -31,7 +31,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import benchmark @@ -112,12 +111,12 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): else: tol = 1e-14 # Tests that a ~= q*r. - a_recon = math_ops.matmul(q, r) + a_recon = test_util.matmul_without_tf32(q, r) self.assertAllClose(a_recon, a, rtol=tol, atol=tol) def CheckUnitary(self, x): # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. - xx = math_ops.matmul(x, x, adjoint_a=True) + xx = test_util.matmul_without_tf32(x, x, adjoint_a=True) identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) if is_single: tol = 1e-5 diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD index 35129a59b84..58ec1f09732 100644 --- a/tensorflow/python/kernel_tests/random/BUILD +++ b/tensorflow/python/kernel_tests/random/BUILD @@ -1,7 +1,7 @@ # Tests of TensorFlow kernels written using the Python API. -load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") package( default_visibility = ["//tensorflow:internal"], @@ -24,7 +24,6 @@ cuda_py_test( name = "parameterized_truncated_normal_op_test", size = "medium", srcs = ["parameterized_truncated_normal_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client", @@ -43,7 +42,9 @@ tf_py_test( name = "random_shuffle_queue_test", size = "small", srcs = ["random_shuffle_queue_test.py"], - tfrt_enabled = True, + tags = [ + "no_cuda_on_cpu_tap", # TODO(b/171060960) flakyly broken assertions + ], deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:data_flow_ops", @@ -59,7 +60,6 @@ cuda_py_test( name = "multinomial_op_test", size = "small", srcs = ["multinomial_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -79,7 +79,6 @@ cuda_py_test( size = "medium", srcs = ["multinomial_op_big_test.py"], shard_count = 3, - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -98,7 +97,6 @@ cuda_py_test( name = "random_crop_test", size = "small", srcs = ["random_crop_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:random_ops", @@ -110,7 +108,6 @@ cuda_py_test( name = "random_ops_test", size = "medium", srcs = ["random_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -125,7 +122,6 @@ cuda_py_test( size = "medium", srcs = ["stateless_random_ops_test.py"], shard_count = 10, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -159,7 +155,6 @@ cuda_py_test( name = "random_grad_test", size = "small", srcs = ["random_grad_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -180,7 +175,6 @@ tf_py_test( srcs = ["random_binomial_test.py"], shard_count = 3, tags = ["no_oss"], - tfrt_enabled = True, deps = [ ":util", "//tensorflow/python:array_ops", @@ -197,7 +191,6 @@ cuda_py_test( name = "random_poisson_test", size = "medium", srcs = ["random_poisson_test.py"], - tfrt_enabled = True, deps = [ ":util", "//tensorflow/python:array_ops", diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index 8bf5a08a358..6a0f40108a8 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -728,9 +728,7 @@ class MinReductionTest(test.TestCase): def _compareAll(self, x, reduction_axes): self._compare(x, reduction_axes, False, use_gpu=True) - self._compare(x, reduction_axes, False, use_gpu=False) self._compare(x, reduction_axes, True, use_gpu=True) - self._compare(x, reduction_axes, True, use_gpu=False) def testAxesType(self): for dtype in [dtypes.int64, dtypes.int32]: @@ -739,13 +737,12 @@ class MinReductionTest(test.TestCase): tf_v = self.evaluate(v) self.assertAllEqual(tf_v, 0) - @test_util.run_deprecated_v1 - def testInfinity(self): + def testSpecialValues(self): for dtype in [np.float32, np.float64]: - for special_value_x in [-np.inf, np.inf]: - for special_value_y in [-np.inf, np.inf]: - np_arr = np.array([special_value_x, special_value_y]).astype(dtype) - self._compareAll(np_arr, None) + for size in range(1, 4): + for arr in itertools.product([-np.inf, 1., np.nan, np.inf], + repeat=size): + self._compareAll(np.array(arr, dtype=dtype), None) def testFloatReduce3D(self): # Create a 3D array of floats and reduce across all possible @@ -847,9 +844,7 @@ class MaxReductionTest(test.TestCase): def _compareAll(self, x, reduction_axes): self._compare(x, reduction_axes, False, use_gpu=True) - self._compare(x, reduction_axes, False, use_gpu=False) self._compare(x, reduction_axes, True, use_gpu=True) - self._compare(x, reduction_axes, True, use_gpu=False) def testAxesType(self): for dtype in [dtypes.int64, dtypes.int32]: @@ -858,13 +853,12 @@ class MaxReductionTest(test.TestCase): tf_v = self.evaluate(v) self.assertAllEqual(tf_v, 0) - @test_util.run_deprecated_v1 - def testInfinity(self): + def testSpecialValues(self): for dtype in [np.float32, np.float64]: - for special_value_x in [-np.inf, np.inf]: - for special_value_y in [-np.inf, np.inf]: - np_arr = np.array([special_value_x, special_value_y]).astype(dtype) - self._compareAll(np_arr, None) + for size in range(1, 4): + for arr in itertools.product([-np.inf, 1., np.nan, np.inf], + repeat=size): + self._compareAll(np.array(arr, dtype=dtype), None) def testInt64Reduce3D(self): # Create a 3D array of int64s and reduce across all possible diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 44279a98a39..87fc2bec2e2 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -103,7 +103,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, v0 = resource_variable_ops.ResourceVariable(1.0) self.assertAllEqual(v0.numpy(), 1.0) - @test_util.disable_tfrt("b/169375363: error code support") def testReadVariableDtypeMismatchEager(self): with context.eager_mode(): handle = resource_variable_ops.var_handle_op( @@ -199,7 +198,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, value, _ = sess.run([v, v.assign_add(1.0)]) self.assertAllEqual(value, 0.0) - @test_util.disable_tfrt("b/169375363: error code support") def testAssignVariableDtypeMismatchEager(self): with context.eager_mode(): handle = resource_variable_ops.var_handle_op( @@ -750,7 +748,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, self.assertEqual(v.handle.op.colocation_groups(), v.initializer.inputs[1].op.colocation_groups()) - @test_util.disable_tfrt("b/169375363: error code support") def testCountUpTo(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable(0, name="upto") @@ -758,7 +755,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, with self.assertRaises(errors.OutOfRangeError): v.count_up_to(1) - @test_util.disable_tfrt("b/169375363: error code support") def testCountUpToFunction(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable(0, name="upto") @@ -857,7 +853,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, variable_def=other_v_def) self.assertIsNotNone(other_v_prime._cached_value) - @test_util.disable_tfrt("b/169375363: error code support") def testVariableDefInitializedInstances(self): with ops.Graph().as_default(), self.cached_session(): v_def = resource_variable_ops.ResourceVariable( @@ -979,7 +974,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, self.evaluate(assign_without_read) self.assertEqual(0.0, self.evaluate(v.value())) - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes @test_util.run_v1_only("b/120545219") def testDestroyResource(self): @@ -1006,7 +1000,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, [assign], feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)}) - @test_util.disable_tfrt("b/169375363: error code support") def testAssignDifferentShapesEagerNotAllowed(self): with context.eager_mode(): with variable_scope.variable_scope("foo"): @@ -1069,7 +1062,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, .batch_scatter_update(batch_slices2), [[1, 3], [2, 3]]) - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testInitValueWrongShape(self): with self.assertRaisesWithPredicateMatch( @@ -1088,7 +1080,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, self.assertEqual(v.dtype, w.dtype) # TODO(alive): get caching to work in eager mode. - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_deprecated_v1 def testCachingDevice(self): with ops.device("/job:server/task:1"): @@ -1105,7 +1096,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, with self.assertRaises(ValueError): _ = w.value().op.get_attr("_class") - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_deprecated_v1 def testSharedName(self): with self.cached_session(): @@ -1164,7 +1154,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, v.initializer.run(feed_dict={v.initial_value: 3.0}) self.assertEqual(3.0, v.value().eval()) - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_v1_only("b/120545219") def testControlFlowInitialization(self): """Expects an error if an initializer is in a control-flow scope.""" @@ -1252,7 +1241,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, self.assertEqual(1, v1.read_value().numpy()) self.assertEqual(2, v2.read_value().numpy()) - @test_util.disable_tfrt("b/169375363: error code support") def testDestruction(self): with context.eager_mode(): var = resource_variable_ops.ResourceVariable(initial_value=1.0, @@ -1340,7 +1328,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, state_ops.scatter_update(v, [1], [3]) self.assertAllEqual([1.0, 3.0], v.numpy()) - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testScatterUpdateInvalidArgs(self): v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update") @@ -1350,7 +1337,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, with self.assertRaisesRegex(Exception, r"shape.*2.*3"): state_ops.scatter_update(v, [0, 1], [0, 1, 2]) - @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testAssignIncompatibleShape(self): v = resource_variable_ops.ResourceVariable([0, 1, 2, 3]) diff --git a/tensorflow/python/kernel_tests/scan_ops_test.py b/tensorflow/python/kernel_tests/scan_ops_test.py index 2a3021f9821..b0161b8d232 100644 --- a/tensorflow/python/kernel_tests/scan_ops_test.py +++ b/tensorflow/python/kernel_tests/scan_ops_test.py @@ -105,6 +105,15 @@ class CumsumTest(test.TestCase): axis = constant_op.constant(0, axis_dtype) tf_out = math_ops.cumsum(x, axis).eval() + @test_util.run_deprecated_v1 + def testNaN(self): + for dtype in (np.float16, np.float32, np.float64): + for nan_idx in range(0, 5): + x = np.arange(1, 6).reshape([5]).astype(dtype) + x[nan_idx] = np.nan + for axis in (-1, 0): + self._compareAll(x, axis) + @test_util.run_deprecated_v1 def test1D(self): for dtype in self.valid_dtypes: @@ -229,6 +238,15 @@ class CumprodTest(test.TestCase): axis = constant_op.constant(0, axis_dtype) tf_out = math_ops.cumprod(x, axis).eval() + @test_util.run_deprecated_v1 + def testNaN(self): + for dtype in (np.float16, np.float32, np.float64): + for nan_idx in range(0, 5): + x = np.arange(1, 6).reshape([5]).astype(dtype) + x[nan_idx] = np.nan + for axis in (-1, 0): + self._compareAll(x, axis) + @test_util.run_deprecated_v1 def test1D(self): for dtype in self.valid_dtypes: diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py index a895827d5fe..b17a8f02594 100644 --- a/tensorflow/python/kernel_tests/topk_op_test.py +++ b/tensorflow/python/kernel_tests/topk_op_test.py @@ -102,11 +102,13 @@ class TopKTest(test.TestCase): self._validateTopK(inputs, 2, [[0.4, 0.3], [0.4, 0.3]], [[3, 1], [2, 1]]) def testTop3(self): - k = 5 - inputs = np.random.permutation(np.linspace(0, 100, 6140, dtype=np.float64)) - indices = np.argsort(-inputs)[:k] - values = -np.sort(-inputs)[:k] - self._validateTopK(inputs, k, values, indices) + for k in range(3, 11, 2): + for dim in range(512, 12288, 512): + inputs = np.random.permutation( + np.linspace(0, 100, dim, dtype=np.float64)) + indices = np.argsort(-inputs)[:k] + values = -np.sort(-inputs)[:k] + self._validateTopK(inputs, k, values, indices) def testTop1AllNan(self): inputs = [[np.NaN, np.NaN], [np.NaN, np.NaN]] diff --git a/tensorflow/python/kernel_tests/v1_compat_tests/BUILD b/tensorflow/python/kernel_tests/v1_compat_tests/BUILD index ac04803ba3b..bd9c02d8101 100644 --- a/tensorflow/python/kernel_tests/v1_compat_tests/BUILD +++ b/tensorflow/python/kernel_tests/v1_compat_tests/BUILD @@ -12,7 +12,6 @@ tf_py_test( name = "identity_op_py_test", size = "small", srcs = ["identity_op_py_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", @@ -24,7 +23,6 @@ cuda_py_test( name = "scatter_nd_ops_test", size = "small", srcs = ["scatter_nd_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:state_ops", @@ -37,7 +35,6 @@ cuda_py_test( name = "session_ops_test", size = "small", srcs = ["session_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index 4b9c1fed916..0402e129c19 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -171,6 +171,27 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): self.assertAllEqual(fnWithLoop(), 4.0) + def testDeviceLabelsInherited(self): + def _LoopBody(i, y): + result = math_ops.cos(y) + self.assertIn("CPU:10", result.device) + with ops.device("CPU:11"): + result = array_ops.identity(result) + self.assertIn("CPU:11", result.device) + return i + 1, result + + @def_function.function + def _FunctionWithWhileLoop(): + x = constant_op.constant(1.) + with ops.device("CPU:10"): + _, z = while_loop_v2( + lambda i, _: i < 2, + _LoopBody, + [0, x]) + return z + # The test assertion runs at trace time. + _FunctionWithWhileLoop.get_concrete_function() + def testExternalControlDependencies(self): with ops.Graph().as_default(), self.test_session(): v = variables.Variable(1.) diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index c50a0ff246c..31def39a98e 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -387,6 +387,9 @@ PyArray_Descr NPyBfloat16_Descr = { nullptr, // fields nullptr, // names &NPyBfloat16_ArrFuncs, // f + nullptr, // metadata + nullptr, // c_metadata + -1, // hash }; // Registered numpy type ID. Global variable populated by the registration code. diff --git a/tensorflow/python/module/BUILD b/tensorflow/python/module/BUILD index 1b93d8cf8b6..1bae3ce21a1 100644 --- a/tensorflow/python/module/BUILD +++ b/tensorflow/python/module/BUILD @@ -23,13 +23,21 @@ py_library( tf_py_test( name = "module_test", srcs = ["module_test.py"], - tfrt_enabled = True, deps = [ ":module", "//tensorflow/python:client_testlib", + "//tensorflow/python:composite_tensor", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:tf2", + "//tensorflow/python:type_spec", "//tensorflow/python:variables", - "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/distribute:ps_values", "//tensorflow/python/distribute:tpu_values", + "//tensorflow/python/distribute:values", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/module/module.py b/tensorflow/python/module/module.py index 8878db4ebe2..ca42de89b95 100644 --- a/tensorflow/python/module/module.py +++ b/tensorflow/python/module/module.py @@ -157,7 +157,7 @@ class Module(tracking.AutoTrackable): name) followed by variables from all submodules recursively (breadth first). """ - return tuple(self._flatten(predicate=_is_variable)) + return tuple(self._flatten(predicate=_is_variable, expand_composites=True)) @property def trainable_variables(self): @@ -172,7 +172,8 @@ class Module(tracking.AutoTrackable): name) followed by variables from all submodules recursively (breadth first). """ - return tuple(self._flatten(predicate=_is_trainable_variable)) + return tuple( + self._flatten(predicate=_is_trainable_variable, expand_composites=True)) @property def submodules(self): @@ -202,7 +203,8 @@ class Module(tracking.AutoTrackable): recursive=True, predicate=None, attribute_traversal_key=None, - with_path=False): + with_path=False, + expand_composites=False): """Flattened attribute values in sorted order by attribute name. Modules are flattened by first walking their attributes in name order. @@ -247,6 +249,8 @@ class Module(tracking.AutoTrackable): as the object itself. If `with_path` is `True` then leaves will not be de-duplicated (e.g. if the same leaf instance is reachable via multiple modules then it will be yielded multiple times with different paths). + expand_composites: If true, then composite tensors are expanded into their + component tensors. Returns: Flat generator for leaves of the current module and optionally all @@ -261,7 +265,8 @@ class Module(tracking.AutoTrackable): predicate=predicate, attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES, attribute_traversal_key=attribute_traversal_key, - with_path=with_path) + with_path=with_path, + expand_composites=expand_composites) @classmethod def with_name_scope(cls, method): @@ -326,6 +331,7 @@ def _flatten_module(module, attribute_traversal_key, attributes_to_ignore, with_path, + expand_composites, module_path=(), seen=None): """Implementation of `flatten`.""" @@ -341,7 +347,8 @@ def _flatten_module(module, prop = module_dict[key] try: - leaves = nest.flatten_with_tuple_paths(prop) + leaves = nest.flatten_with_tuple_paths( + prop, expand_composites=expand_composites) except Exception as cause: # pylint: disable=broad-except six.raise_from( ValueError( @@ -376,6 +383,7 @@ def _flatten_module(module, attribute_traversal_key=attribute_traversal_key, attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES, # pylint: disable=protected-access with_path=with_path, + expand_composites=expand_composites, module_path=submodule_path, seen=seen) diff --git a/tensorflow/python/module/module_test.py b/tensorflow/python/module/module_test.py index e15bc734230..2d1b1627655 100644 --- a/tensorflow/python/module/module_test.py +++ b/tensorflow/python/module/module_test.py @@ -31,8 +31,10 @@ from tensorflow.python.distribute import tpu_values from tensorflow.python.distribute import values as distributed_values from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.framework import type_spec from tensorflow.python.module import module from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -260,6 +262,37 @@ class VariableTrackingTest(test_util.TensorFlowTestCase): m.c = aggregating self.assertEqual(m.variables, (mirrored, tpu, aggregating)) + def test_composite_variable(self): + + class Spec(type_spec.TypeSpec): + + value_type = property(lambda self: CompositeVariable) + + def _component_specs(self): + pass + + def _serialize(self): + pass + + def _to_components(self, value): + return value._variables + + def _from_components(self, variable_list): + return CompositeVariable(variable_list) + + class CompositeVariable(composite_tensor.CompositeTensor): + + def __init__(self, variable_list): + self._variables = variable_list + + @property + def _type_spec(self): + return Spec() + + m = module.Module() + m.a = CompositeVariable([variables.Variable(1.), variables.Variable(2.)]) + self.assertAllEqual(m.variables, m.a._variables) + class ModuleTrackingTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/modules_with_exports.py b/tensorflow/python/modules_with_exports.py index 8d03133d266..05445edc669 100644 --- a/tensorflow/python/modules_with_exports.py +++ b/tensorflow/python/modules_with_exports.py @@ -64,6 +64,9 @@ from tensorflow.python.framework.test_combinations import * from tensorflow.python.util.tf_decorator import make_decorator from tensorflow.python.util.tf_decorator import unwrap +from tensorflow.python.distribute.parameter_server_strategy_v2 import * +from tensorflow.python.distribute.coordinator.cluster_coordinator import * + tf_export('__internal__.decorator.make_decorator', v1=[])(make_decorator) tf_export('__internal__.decorator.unwrap', v1=[])(unwrap) diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index b8c59bcbb6c..07edd54b494 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -665,7 +665,7 @@ def _GatherV2Grad(op, grad): # For axis 0 gathers, build an appropriately shaped IndexedSlices. if axis_static == 0: if context.executing_eagerly(): - with ops.device("/cpu:0"): + with ops.device(indices_size.device): params_tail_shape = array_ops.identity(params_shape)[1:] else: params_tail_shape = params_shape[1:] diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py index 6afe923d795..f2378995597 100644 --- a/tensorflow/python/ops/collective_ops.py +++ b/tensorflow/python/ops/collective_ops.py @@ -177,9 +177,8 @@ def all_gather_v2(t, Returns: An Op implementing the distributed operation. """ - return gen_collective_ops.collective_gather( + return gen_collective_ops.collective_gather_v2( t, - shape=[0], group_size=group_size, group_key=group_key, instance_key=instance_key, diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index 5bdd2494e91..75130fcd8a7 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -26,6 +26,7 @@ from __future__ import print_function import collections from tensorflow.python.eager import backprop_util +from tensorflow.python.eager import function from tensorflow.python.framework import auto_control_deps from tensorflow.python.framework import auto_control_deps_utils as acd from tensorflow.python.framework import constant_op @@ -192,6 +193,37 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name return [None] + outputs +def _run_as_function_for_tape_gradients(make_op, cond_inputs): + """Fix higher-order tape gradients by wrapping `make_op` in a function.""" + # GradientTapes created inside a function currently don't work well with + # un-wrapped control flow ops in that same function. Wrapping in an extra + # layer of intermediate function means we run extra logic in the function + # gradient code to record the correct intermediates on the tape. + # + # The function attribute inputs to cond/case ops are not hashable, so we pass + # everything as a capture to bypass defun's caching. + if (gradients_util.PossibleTapeGradientTypes(cond_inputs) + == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER + # We only need one function between the tape and the cond; if we've + # already wrapped once, we stop wrapping to avoid infinite recursion. + and not (ops.get_default_graph().building_function + and "cond_gradient_wrapper" in ops.get_default_graph().name)): + + op = None + def _run_make_and_extract_op(): + # Post-processing happens on the cond op, not the function call op. + nonlocal op + tensors = make_op() + op, tensors = _get_op_and_outputs(tensors) # pylint: disable=unused-variable + return tensors + + return op, function.defun_with_attributes( + _run_make_and_extract_op, + attributes=dict(func_name="cond_gradient_wrapper"))() + else: + return _get_op_and_outputs(make_op()) + + def _build_cond(pred, true_graph, false_graph, @@ -268,16 +300,17 @@ def _build_cond(pred, else: op_fn = gen_functional_ops.stateless_if - tensors = op_fn( - pred, - cond_inputs, [t.dtype for t in true_graph.outputs], - util.create_new_tf_function(true_graph), - util.create_new_tf_function(false_graph), - output_shapes=_get_output_shapes(true_graph.outputs, - false_graph.outputs), - name=name) + def make_op(): + return op_fn( + pred, + cond_inputs, [t.dtype for t in true_graph.outputs], + util.create_new_tf_function(true_graph), + util.create_new_tf_function(false_graph), + output_shapes=_get_output_shapes(true_graph.outputs, + false_graph.outputs), + name=name) + if_op, tensors = _run_as_function_for_tape_gradients(make_op, cond_inputs) - if_op, tensors = _get_op_and_outputs(tensors) # `if_op` is None if this is a `StatelessIf` op with no outputs. if if_op is not None: if_op._true_graph = true_graph @@ -1156,14 +1189,16 @@ def _build_case(branch_index, # Create the Case op. with ops.control_dependencies( sum((list(bg.control_captures) for bg in branch_graphs), [])): - tensors = op_fn( - branch_index, - case_inputs, [t.dtype for t in branch_graphs[0].outputs], - [util.create_new_tf_function(g) for g in branch_graphs], - output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]), - name=name) - case_op, tensors = _get_op_and_outputs(tensors) + def _make_op(): + return op_fn( + branch_index, + case_inputs, [t.dtype for t in branch_graphs[0].outputs], + [util.create_new_tf_function(g) for g in branch_graphs], + output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]), + name=name) + case_op, tensors = _run_as_function_for_tape_gradients( + _make_op, case_inputs) if case_op is not None: util.maybe_set_lowering_attr(case_op, lower_using_switch_merge) diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py index 0bbee785641..6f1bd352e2b 100644 --- a/tensorflow/python/ops/control_flow_util_v2.py +++ b/tensorflow/python/ops/control_flow_util_v2.py @@ -28,11 +28,11 @@ from tensorflow.python.framework import ops from tensorflow.python.framework.func_graph import FuncGraph from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_v2_func_graphs +from tensorflow.python.util import keras_deps from tensorflow.python.util import tf_contextlib _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None -_KERAS_LAYER_CONTEXT_FUNCTION = None _DISABLE_LOWER_USING_SWITCH_MERGE = False @@ -242,18 +242,11 @@ def _is_tpu_strategy(strategy): strategy.__class__.__name__.startswith("TPUStrategy")) -def _register_keras_layer_context_function(func): - global _KERAS_LAYER_CONTEXT_FUNCTION - # TODO(scottzhu): Disable duplicated inject once keras is moved to - # third_party/py/keras. - _KERAS_LAYER_CONTEXT_FUNCTION = func - - def _is_building_keras_layer(): # TODO(srbs): Remove this function when we no long support session with Keras. - global _KERAS_LAYER_CONTEXT_FUNCTION - if _KERAS_LAYER_CONTEXT_FUNCTION is not None: - return _KERAS_LAYER_CONTEXT_FUNCTION().layer is not None + keras_call_context_function = keras_deps.get_call_context_function() + if keras_call_context_function: + return keras_call_context_function().layer is not None else: return False diff --git a/tensorflow/python/ops/control_flow_v2_func_graphs.py b/tensorflow/python/ops/control_flow_v2_func_graphs.py index 23edd712797..edf14bc9755 100644 --- a/tensorflow/python/ops/control_flow_v2_func_graphs.py +++ b/tensorflow/python/ops/control_flow_v2_func_graphs.py @@ -22,43 +22,39 @@ from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops -class CondBranchFuncGraph(func_graph.FuncGraph): +class ControlFlowFuncGraph(func_graph.FuncGraph): + """Contains control flow-specific FuncGraph logic.""" + + def __init__(self, *args, **kwargs): + super(ControlFlowFuncGraph, self).__init__(*args, **kwargs) + outer_graph = self.outer_graph + # Unlike tf.function, control flow FuncGraphs are generally created one per + # op. This means hard-coding any outer device scopes in the body (rather + # than inspecting the call-time placement of the control flow op) makes + # sense. + self._device_function_stack = outer_graph._device_function_stack.copy() # pylint: disable=protected-access + self.is_control_flow_graph = True + if ops.executing_eagerly_outside_functions(): + func_graph.override_func_graph_name_scope( + self, self.outer_graph.get_name_scope()) + + +class CondBranchFuncGraph(ControlFlowFuncGraph): """FuncGraph for branches of tf.cond(). This is used to distinguish cond branches from other functions. """ - def __init__(self, *args, **kwargs): - super(CondBranchFuncGraph, self).__init__(*args, **kwargs) - self.is_control_flow_graph = True - if ops.executing_eagerly_outside_functions(): - func_graph.override_func_graph_name_scope( - self, self.outer_graph.get_name_scope()) - -class WhileCondFuncGraph(func_graph.FuncGraph): +class WhileCondFuncGraph(ControlFlowFuncGraph): """FuncGraph for the condition of tf.while_loop(). This is used to distinguish while conditions from other functions. """ - def __init__(self, *args, **kwargs): - super(WhileCondFuncGraph, self).__init__(*args, **kwargs) - self.is_control_flow_graph = True - if ops.executing_eagerly_outside_functions(): - func_graph.override_func_graph_name_scope( - self, self.outer_graph.get_name_scope()) - -class WhileBodyFuncGraph(func_graph.FuncGraph): +class WhileBodyFuncGraph(ControlFlowFuncGraph): """FuncGraph for the body of tf.while_loop(). This is used to distinguish while bodies from other functions. """ - - def __init__(self, *args, **kwargs): - super(WhileBodyFuncGraph, self).__init__(*args, **kwargs) - self.is_control_flow_graph = True - if ops.executing_eagerly_outside_functions(): - func_graph.override_func_graph_name_scope( - self, self.outer_graph.get_name_scope()) diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py index 4d4df0ffa48..c356e82ac1f 100644 --- a/tensorflow/python/ops/gradients_util.py +++ b/tensorflow/python/ops/gradients_util.py @@ -24,6 +24,7 @@ import contextlib from six.moves import xrange, zip # pylint: disable=redefined-builtin from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python import pywrap_tfe from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop_util from tensorflow.python.eager import context @@ -1007,3 +1008,15 @@ def _AggregatedGrads(grads, # out_grads[i] is [], thus its aggregation is simply None. out_grads[i] = None return out_grads + + +# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are +# unfortunately too slow to use here. +POSSIBLE_GRADIENT_TYPES_NONE = 0 +POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1 +POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2 + + +def PossibleTapeGradientTypes(tensors): + """Determines whether and how `args` may require tape gradients.""" + return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 8a767597f00..75e66d8f513 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -2964,6 +2964,12 @@ def decode_image(contents, function to `False`, in which case the op will return 3-dimensional tensors and will truncate animated GIF files to the first frame. + NOTE: If the first frame of an animated GIF does not occupy the entire + canvas (maximum frame width x maximum frame height), then it fills the + unoccupied areas (in the first frame) with zeros (black). For frames after the + first frame that does not occupy the entire canvas, it uses the previous + frame to fill the unoccupied areas. + Args: contents: 0-D `string`. The encoded image bytes. channels: An optional `int`. Defaults to `0`. Number of color channels for diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 6a651bbdcce..320facf5afa 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -2007,7 +2007,8 @@ class CentralCropTest(test_util.TensorFlowTestCase): self.assertTrue(y.op.name.startswith("central_crop")) -class PadToBoundingBoxTest(test_util.TensorFlowTestCase): +class PadToBoundingBoxTest(test_util.TensorFlowTestCase, + parameterized.TestCase): def _PadToBoundingBox(self, x, offset_height, offset_width, target_height, target_width, use_tensor_inputs): @@ -2172,7 +2173,10 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase): "inner 3 dims of \\'image.shape\\' must be > 0", use_tensor_inputs_options=[True]) - def testBadParams(self): + def testBadParamsScalarInputs(self): + # In this test, inputs do not get converted to tensors before calling the + # tf.function. The error message here is raised in python + # since the python function has direct access to the scalars. x_shape = [3, 3, 1] x = np.zeros(x_shape) @@ -2187,9 +2191,49 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase): "height must be <= target - offset"), (0, 2, 4, 4, "width must be <= target - offset")) - for config_item in test_config: - self._assertRaises(x, x_shape, *config_item) + self._assertRaises( + x, x_shape, *config_item, use_tensor_inputs_options=[False]) + + def testBadParamsTensorInputsEager(self): + # In this test inputs get converted to EagerTensors before calling the + # tf.function. The error message here is raised in python + # since the python function has direct access to the tensor's values. + with context.eager_mode(): + x_shape = [3, 3, 1] + x = np.zeros(x_shape) + + # Each line is a test configuration: + # offset_height, offset_width, target_height, target_width, err_msg + test_config = ( + (-1, 0, 4, 4, + "offset_height must be >= 0"), + (0, -1, 4, 4, + "offset_width must be >= 0"), + (2, 0, 4, 4, + "height must be <= target - offset"), + (0, 2, 4, 4, + "width must be <= target - offset")) + for config_item in test_config: + self._assertRaises( + x, x_shape, *config_item, use_tensor_inputs_options=[True]) + + @parameterized.named_parameters([("OffsetHeight", (-1, 0, 4, 4)), + ("OffsetWidth", (0, -1, 4, 4)), + ("Height", (2, 0, 4, 4)), + ("Width", (0, 2, 4, 4))]) + def testBadParamsTensorInputsGraph(self, config): + # In this test inputs get converted to tensors before calling the + # tf.function. The error message here is raised during shape inference. + with context.graph_mode(): + x_shape = [3, 3, 1] + x = np.zeros(x_shape) + self._assertRaises( + x, + x_shape, + *config, + "Paddings must be non-negative", + use_tensor_inputs_options=[True]) def testNameScope(self): # Testing name scope requires a graph. @@ -5075,7 +5119,6 @@ class PSNRTest(test_util.TensorFlowTestCase): """Returns an image or image batch with given shape.""" return np.random.rand(*shape).astype(np.float32) * max_val - @test_util.run_deprecated_v1 def testPSNRSingleImage(self): image1 = self._RandomImage((8, 8, 1), 1) image2 = self._RandomImage((8, 8, 1), 1) @@ -5086,10 +5129,9 @@ class PSNRTest(test_util.TensorFlowTestCase): dtype=dtypes.float32) tf_image2 = constant_op.constant(image2, shape=image2.shape, dtype=dtypes.float32) - tf_psnr = image_ops.psnr(tf_image1, tf_image2, 1.0, "psnr").eval() + tf_psnr = self.evaluate(image_ops.psnr(tf_image1, tf_image2, 1.0, "psnr")) self.assertAllClose(psnr, tf_psnr, atol=0.001) - @test_util.run_deprecated_v1 def testPSNRMultiImage(self): image1 = self._RandomImage((10, 8, 8, 1), 1) image2 = self._RandomImage((10, 8, 8, 1), 1) @@ -5100,10 +5142,9 @@ class PSNRTest(test_util.TensorFlowTestCase): dtype=dtypes.float32) tf_image2 = constant_op.constant(image2, shape=image2.shape, dtype=dtypes.float32) - tf_psnr = image_ops.psnr(tf_image1, tf_image2, 1, "psnr").eval() + tf_psnr = self.evaluate(image_ops.psnr(tf_image1, tf_image2, 1, "psnr")) self.assertAllClose(psnr, tf_psnr, atol=0.001) - @test_util.run_deprecated_v1 def testGoldenPSNR(self): q20, q72, q95 = self._LoadTestImages() @@ -5121,23 +5162,21 @@ class PSNRTest(test_util.TensorFlowTestCase): tf_q20 = constant_op.constant(q20, shape=q20.shape, dtype=dtypes.float32) tf_q72 = constant_op.constant(q72, shape=q72.shape, dtype=dtypes.float32) tf_q95 = constant_op.constant(q95, shape=q95.shape, dtype=dtypes.float32) - tf_psnr1 = image_ops.psnr(tf_q20, tf_q72, 1, "psnr1").eval() - tf_psnr2 = image_ops.psnr(tf_q20, tf_q95, 1, "psnr2").eval() - tf_psnr3 = image_ops.psnr(tf_q72, tf_q95, 1, "psnr3").eval() + tf_psnr1 = self.evaluate(image_ops.psnr(tf_q20, tf_q72, 1, "psnr1")) + tf_psnr2 = self.evaluate(image_ops.psnr(tf_q20, tf_q95, 1, "psnr2")) + tf_psnr3 = self.evaluate(image_ops.psnr(tf_q72, tf_q95, 1, "psnr3")) self.assertAllClose(psnr1, tf_psnr1, atol=0.001) self.assertAllClose(psnr2, tf_psnr2, atol=0.001) self.assertAllClose(psnr3, tf_psnr3, atol=0.001) - @test_util.run_deprecated_v1 def testInfinity(self): q20, _, _ = self._LoadTestImages() psnr = self._PSNR_NumPy(q20, q20, 1) with self.cached_session(use_gpu=True): tf_q20 = constant_op.constant(q20, shape=q20.shape, dtype=dtypes.float32) - tf_psnr = image_ops.psnr(tf_q20, tf_q20, 1, "psnr").eval() + tf_psnr = self.evaluate(image_ops.psnr(tf_q20, tf_q20, 1, "psnr")) self.assertAllClose(psnr, tf_psnr, atol=0.001) - @test_util.run_deprecated_v1 def testInt(self): img1 = self._RandomImage((10, 8, 8, 1), 255) img2 = self._RandomImage((10, 8, 8, 1), 255) @@ -5149,7 +5188,7 @@ class PSNRTest(test_util.TensorFlowTestCase): psnr_float32 = image_ops.psnr(img1, img2, 1.0) with self.cached_session(use_gpu=True): self.assertAllClose( - psnr_uint8.eval(), self.evaluate(psnr_float32), atol=0.001) + self.evaluate(psnr_uint8), self.evaluate(psnr_float32), atol=0.001) class SSIMTest(test_util.TensorFlowTestCase): @@ -5179,18 +5218,21 @@ class SSIMTest(test_util.TensorFlowTestCase): """Returns an image or image batch with given shape.""" return np.random.rand(*shape).astype(np.float32) * max_val - @test_util.run_deprecated_v1 def testAgainstMatlab(self): """Tests against values produced by Matlab.""" img = self._LoadTestImages() expected = self._ssim[np.triu_indices(3)] - ph = [array_ops.placeholder(dtype=dtypes.float32) for _ in range(2)] - ssim = image_ops.ssim( - *ph, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03) + def ssim_func(x): + return image_ops.ssim( + *x, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03) + with self.cached_session(use_gpu=True): - scores = [ssim.eval(dict(zip(ph, t))) - for t in itertools.combinations_with_replacement(img, 2)] + scores = [ + self.evaluate(ssim_func(t)) + for t in itertools.combinations_with_replacement(img, 2) + ] + self.assertAllClose(expected, np.squeeze(scores), atol=1e-4) def testBatch(self): @@ -5248,7 +5290,6 @@ class SSIMTest(test_util.TensorFlowTestCase): with self.cached_session(use_gpu=True): self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4) - @test_util.run_deprecated_v1 def testNegative(self): """Tests against negative SSIM index.""" step = np.expand_dims(np.arange(0, 256, 16, dtype=np.uint8), axis=0) @@ -5267,9 +5308,8 @@ class SSIMTest(test_util.TensorFlowTestCase): k1=0.01, k2=0.03) with self.cached_session(use_gpu=True): - self.assertLess(ssim.eval(), 0) + self.assertLess(self.evaluate(ssim), 0) - @test_util.run_deprecated_v1 def testInt(self): img1 = self._RandomImage((1, 16, 16, 3), 255) img2 = self._RandomImage((1, 16, 16, 3), 255) @@ -5283,7 +5323,7 @@ class SSIMTest(test_util.TensorFlowTestCase): img1, img2, 1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03) with self.cached_session(use_gpu=True): self.assertAllClose( - ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001) + self.evaluate(ssim_uint8), self.evaluate(ssim_float32), atol=0.001) class MultiscaleSSIMTest(test_util.TensorFlowTestCase): @@ -5518,7 +5558,7 @@ class SobelEdgesTest(test_util.TensorFlowTestCase): @test_util.run_all_in_graph_and_eager_modes -class DecodeImageTest(test_util.TensorFlowTestCase): +class DecodeImageTest(test_util.TensorFlowTestCase, parameterized.TestCase): _FORWARD_COMPATIBILITY_HORIZONS = [ (2020, 1, 1), @@ -5698,7 +5738,7 @@ class DecodeImageTest(test_util.TensorFlowTestCase): first_frame = array_ops.gather(animation, 0) image1 = image_ops.convert_image_dtype(first_frame, dtypes.float32) image0, image1 = self.evaluate([image0, image1]) - self.assertEqual(len(image0.shape), 3) + self.assertLen(image0.shape, 3) self.assertAllEqual(list(image0.shape), [40, 20, 3]) self.assertAllEqual(image0, image1) @@ -5706,10 +5746,103 @@ class DecodeImageTest(test_util.TensorFlowTestCase): image2 = image_ops.decode_image(gif0, dtype=dtypes.float32) image3 = image_ops.convert_image_dtype(animation, dtypes.float32) image2, image3 = self.evaluate([image2, image3]) - self.assertEqual(len(image2.shape), 4) + self.assertLen(image2.shape, 4) self.assertAllEqual(list(image2.shape), [12, 40, 20, 3]) self.assertAllEqual(image2, image3) + def testImageCropAndResize(self): + if test_util.is_gpu_available(): + op = image_ops_impl.crop_and_resize_v2( + image=array_ops.zeros((2, 1, 1, 1)), + boxes=[[1.0e+40, 0, 0, 0]], + box_indices=[1], + crop_size=[1, 1]) + self.evaluate(op) + else: + message = "Boxes contains at least one element that is not finite" + with self.assertRaisesRegex((errors.InvalidArgumentError, ValueError), + message): + op = image_ops_impl.crop_and_resize_v2( + image=array_ops.zeros((2, 1, 1, 1)), + boxes=[[1.0e+40, 0, 0, 0]], + box_indices=[1], + crop_size=[1, 1]) + self.evaluate(op) + + @parameterized.named_parameters( + ("_jpeg", "JPEG", "jpeg_merge_test1.jpg"), + ("_png", "PNG", "lena_rgba.png"), + ("_gif", "GIF", "scan.gif"), + ) + def testWrongOpBmp(self, img_format, filename): + base_folder = "tensorflow/core/lib" + base_path = os.path.join(base_folder, img_format.lower(), "testdata") + err_msg = "Trying to decode " + img_format + " format using DecodeBmp op" + with self.assertRaisesRegex( + (ValueError, errors.InvalidArgumentError), err_msg): + img_bytes = io_ops.read_file(os.path.join(base_path, filename)) + img = image_ops.decode_bmp(img_bytes) + self.evaluate(img) + + @parameterized.named_parameters( + ("_jpeg", image_ops.decode_jpeg, "DecodeJpeg"), + ("_png", image_ops.decode_png, "DecodePng"), + ("_gif", image_ops.decode_gif, "DecodeGif"), + ) + def testWrongOp(self, decode_op, op_used): + base = "tensorflow/core/lib/bmp/testdata" + bmp0 = io_ops.read_file(os.path.join(base, "rgba_small.bmp")) + err_msg = ("Trying to decode BMP format using a wrong op. Use `decode_bmp` " + "or `decode_image` instead. Op used: ") + op_used + with self.assertRaisesRegex( + (ValueError, errors.InvalidArgumentError), err_msg): + img = decode_op(bmp0) + self.evaluate(img) + + @parameterized.named_parameters( + ("_png", "PNG", "lena_rgba.png"), + ("_gif", "GIF", "scan.gif"), + ("_bmp", "BMP", "rgba_small.bmp"), + ) + def testWrongOpJpeg(self, img_format, filename): + base_folder = "tensorflow/core/lib" + base_path = os.path.join(base_folder, img_format.lower(), "testdata") + err_msg = ("DecodeAndCropJpeg operation can run on JPEG only, but " + "detected ") + img_format + with self.assertRaisesRegex( + (ValueError, errors.InvalidArgumentError), err_msg): + img_bytes = io_ops.read_file(os.path.join(base_path, filename)) + img = image_ops.decode_and_crop_jpeg(img_bytes, [1, 1, 2, 2]) + self.evaluate(img) + + def testGifFramesWithDiffSize(self): + """Test decoding an animated GIF. + + This test verifies that `decode_image` op can decode animated GIFs whose + first frame does not fill the canvas. The unoccupied areas should be filled + with zeros (black). + + `squares.gif` is animated with two images of different sizes. It + alternates between a smaller image of size 10 x 10 and a larger image of + size 16 x 16. Because it starts animating with the smaller image, the first + frame does not fill the canvas. (Canvas size is equal to max frame width x + max frame height.) + + `red_black.gif` has just a single image in a GIF format. It is the same + image as the smaller image (size 10 x 10) of the two images in + `squares.gif`. The only difference is that its background (canvas - smaller + image) is pre-filled with zeros (black); it is the groundtruth. + """ + base = "tensorflow/core/lib/gif/testdata" + gif_bytes0 = io_ops.read_file(os.path.join(base, "squares.gif")) + image0 = image_ops.decode_image(gif_bytes0, dtype=dtypes.float32, + expand_animations=False) + gif_bytes1 = io_ops.read_file(os.path.join(base, "red_black.gif")) + image1 = image_ops.decode_image(gif_bytes1, dtype=dtypes.float32) + image1_0 = array_ops.gather(image1, 0) + image0, image1_0 = self.evaluate([image0, image1_0]) + self.assertAllEqual(image0, image1_0) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/init_ops_test.py b/tensorflow/python/ops/init_ops_test.py index 4ea7ef007d6..ae8bfbdbdd0 100644 --- a/tensorflow/python/ops/init_ops_test.py +++ b/tensorflow/python/ops/init_ops_test.py @@ -203,8 +203,6 @@ class InitializersTest(test.TestCase): run_metadata=run_metadata) @test_util.run_gpu_only - @test_util.disable_tfrt('b/165614506: Incorrect device name set in ' - 'tfrt::TensorHandle.') def test_eager_orthogonal_gpu(self): with context.eager_mode(): v = variable_scope.get_variable( diff --git a/tensorflow/python/ops/init_ops_v2.py b/tensorflow/python/ops/init_ops_v2.py index 3c110fe9cf9..ef020eef81b 100644 --- a/tensorflow/python/ops/init_ops_v2.py +++ b/tensorflow/python/ops/init_ops_v2.py @@ -12,19 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Operations often used for initializing tensors. - -All variable initializers returned by functions in this file should have the -following signature: - -def _initializer(shape, dtype=dtypes.float32): - Args: - shape: List of `int` representing the shape of the output `Tensor`. Some - initializers may also be able to accept a `Tensor`. - dtype: (Optional) Type of the output `Tensor`. - Returns: - A `Tensor` of type `dtype` and `shape`. -""" +"""Initializers for TF 2.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -44,18 +32,40 @@ from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops.init_ops import _compute_fans from tensorflow.python.util.tf_export import tf_export +_PARTITION_SHAPE = "partition_shape" +_PARTITION_OFFSET = "partition_offset" + class Initializer(object): """Initializer base class: all initializers inherit from this class. + + Initializers should implement a `__call__` method with the following + signature: + + ```python + def __call__(self, shape, dtype=None, **kwargs): + # returns a tensor of shape `shape` and dtype `dtype` + # containing values drawn from a distribution of your choice. + ``` """ - def __call__(self, shape, dtype=None): + def __call__(self, shape, dtype=None, **kwargs): """Returns a tensor object initialized as specified by the initializer. Args: shape: Shape of the tensor. dtype: Optional dtype of the tensor. If not provided will return tensor - of `tf.float32`. + of `tf.float32`. + **kwargs: Additional keyword arguments. Accepted values: + `partition_shape` and `partition_offset`. Used when creating a single + partition in a partitioned variable. `partition_shape` is the shape + of the partition (i.e. the shape of the returned tensor) and + `partition_offset` is a tuple of `int` specifying the offset of this + partition w.r.t each axis. For example, a tensor of shape `(30, 100)` + can be partitioned into two partitions: `p0` of shape `(10, 100)` and + `p1` of shape `(20, 100)`; if the initializer is called with + `partition_shape=(20, 100)` and `partition_offset=(10, 0)`, it should + return the value for `p1`. """ raise NotImplementedError @@ -89,6 +99,14 @@ class Initializer(object): config.pop("dtype", None) return cls(**config) + def _validate_kwargs(self, kwargs, support_partition=True): + for kwarg in kwargs: + if kwarg not in [_PARTITION_SHAPE, _PARTITION_OFFSET]: + raise TypeError("Unknown keyword arguments: %s" % kwarg) + elif not support_partition: + raise ValueError("%s initializer doesn't support partition-related" + " arguments" % self.__class__.__name__) + @tf_export("zeros_initializer", v1=[]) class Zeros(Initializer): @@ -115,20 +133,24 @@ class Zeros(Initializer): (, , >> x = tf.constant([4, float('nan')]) >>> print(tf.reduce_max(x)) - tf.Tensor(4.0, shape=(), dtype=float32) + tf.Tensor(nan, shape=(), dtype=float32) >>> x = tf.constant([float('nan'), float('nan')]) >>> print(tf.reduce_max(x)) - tf.Tensor(-inf, shape=(), dtype=float32) + tf.Tensor(nan, shape=(), dtype=float32) >>> x = tf.constant([float('-inf'), float('inf')]) >>> print(tf.reduce_max(x)) tf.Tensor(inf, shape=(), dtype=float32) @@ -4809,6 +4809,35 @@ def ndtri(x, name=None): return gen_math_ops.ndtri(x) +@tf_export("math.erfcinv") +@dispatch.add_dispatch_support +def erfcinv(x, name=None): + """Computes the inverse of complementary error function. + + Given `x`, compute the inverse complementary error function of `x`. + This function is the inverse of `tf.math.erfc`, and is defined on + `[0, 2]`. + + >>> tf.math.erfcinv([0., 0.2, 1., 1.5, 2.]) + + + Args: + x: `Tensor` with type `float` or `double`. + name: A name for the operation (optional). + Returns: + Inverse complementary error function of `x`. + + @compatibility(numpy) + Equivalent to scipy.special.erfcinv + @end_compatibility + """ + with ops.name_scope(name, "erfcinv", [x]): + x = ops.convert_to_tensor(x, name="start") + return -ndtri(0.5 * x) * np.sqrt(0.5) + + @tf_export("math.ceil", v1=["math.ceil", "ceil"]) @dispatch.add_dispatch_support @deprecation.deprecated_endpoints("ceil") diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index dabf4bb9d33..f2e637b5b09 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +from tensorflow.python.platform import test @test_util.run_all_in_graph_and_eager_modes @@ -882,5 +883,20 @@ class RangeTest(test_util.TensorFlowTestCase): self.assertAllEqual(values, self.evaluate(tensor)) +@test_util.run_all_in_graph_and_eager_modes +class ErfcinvTest(test_util.TensorFlowTestCase): + + def testErfcinv(self): + if test.is_built_with_rocm(): + # The implementation of erfcinv calls ndtri op, + # and the ROCm implementaion for ndtri op has a known bug in it + # whose fix will be in a forthcoming ROCm release (4.0 ?). + # Need to skip this unit-test until that ROCm release is out + self.skipTest("ndtri op implementation is buggy on ROCm") + values = np.random.uniform(0.1, 1.9, size=int(1e4)).astype(np.float32) + approx_id = math_ops.erfc(math_ops.erfcinv(values)) + self.assertAllClose(values, self.evaluate(approx_id)) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/numpy_ops/BUILD b/tensorflow/python/ops/numpy_ops/BUILD index d22c96e50c8..96eb5509de3 100644 --- a/tensorflow/python/ops/numpy_ops/BUILD +++ b/tensorflow/python/ops/numpy_ops/BUILD @@ -43,7 +43,6 @@ py_library( cuda_py_test( name = "np_arrays_test", srcs = ["np_arrays_test.py"], - tfrt_enabled = True, deps = [ ":numpy", "//tensorflow/python:dtypes", @@ -69,7 +68,6 @@ cuda_py_test( cuda_py_test( name = "np_logic_test", srcs = ["np_logic_test.py"], - tfrt_enabled = True, deps = [ ":numpy", "//third_party/py/numpy", @@ -102,7 +100,6 @@ cuda_py_test( cuda_py_test( name = "np_utils_test", srcs = ["np_utils_test.py"], - tfrt_enabled = True, deps = [ ":numpy", "//tensorflow/python:platform", diff --git a/tensorflow/python/ops/numpy_ops/g3doc/TensorFlow_NumPy_Text_Generation.ipynb b/tensorflow/python/ops/numpy_ops/g3doc/TensorFlow_NumPy_Text_Generation.ipynb new file mode 100644 index 00000000000..b9d346d015a --- /dev/null +++ b/tensorflow/python/ops/numpy_ops/g3doc/TensorFlow_NumPy_Text_Generation.ipynb @@ -0,0 +1,1088 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "t09eeeR5prIJ" + }, + "source": [ + "##### Copyright 2020 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "id": "GCCk8_dHpuNf" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ovpZyIhNIgoq" + }, + "source": [ + "# Text generation with an RNN" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hcD2nPQvPOFM" + }, + "source": [ + "
    \n", + " \n", + " \n", + " \n", + " \n", + "
    \n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
    " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BwpJ5IffzRG6" + }, + "source": [ + "This tutorial demonstrates how to generate text using a character-based RNN. We will work with a dataset of Shakespeare's writing from Andrej Karpathy's [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). Given a sequence of characters from this data (\"Shakespear\"), train a model to predict the next character in the sequence (\"e\"). Longer sequences of text can be generated by calling the model repeatedly.\n", + "\n", + "Note: Enable GPU acceleration to execute this notebook faster. In Colab: *Runtime > Change runtime type > Hardware acclerator > GPU*. If running locally make sure TensorFlow version >= 2.4.\n", + "\n", + "This tutorial includes runnable code implemented using [tf.experimental.numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy). The following is sample output when the model in this tutorial trained for 30 epochs, and started with the string \"Q\":\n", + "\n", + "
    \n",
    +    "QUEENE:\n",
    +    "I had thought thou hadst a Roman; for the oracle,\n",
    +    "Thus by All bids the man against the word,\n",
    +    "Which are so weak of care, by old care done;\n",
    +    "Your children were in your holy love,\n",
    +    "And the precipitation through the bleeding throne.\n",
    +    "\n",
    +    "BISHOP OF ELY:\n",
    +    "Marry, and will, my lord, to weep in such a one were prettiest;\n",
    +    "Yet now I was adopted heir\n",
    +    "Of the world's lamentable day,\n",
    +    "To watch the next way with his father with his face?\n",
    +    "\n",
    +    "ESCALUS:\n",
    +    "The cause why then we are all resolved more sons.\n",
    +    "\n",
    +    "VOLUMNIA:\n",
    +    "O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead,\n",
    +    "And love and pale as any will to that word.\n",
    +    "\n",
    +    "QUEEN ELIZABETH:\n",
    +    "But how long have I heard the soul for this world,\n",
    +    "And show his hands of life be proved to stand.\n",
    +    "\n",
    +    "PETRUCHIO:\n",
    +    "I say he look'd on, if I must be content\n",
    +    "To stay him from the fatal of our country's bliss.\n",
    +    "His lordship pluck'd from this sentence then for prey,\n",
    +    "And then let us twain, being the moon,\n",
    +    "were she such a case as fills m\n",
    +    "
    \n", + "\n", + "While some of the sentences are grammatical, most do not make sense. The model has not learned the meaning of words, but consider:\n", + "\n", + "* The model is character-based. When training started, the model did not know how to spell an English word, or that words were even a unit of text.\n", + "\n", + "* The structure of the output resembles a play—blocks of text generally begin with a speaker name, in all capital letters similar to the dataset.\n", + "\n", + "* As demonstrated below, the model is trained on small batches of text (100 characters each), and is still able to generate a longer sequence of text with coherent structure." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "srXC6pLGLwS6" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WGyKZj3bzf9p" + }, + "source": [ + "### Import TensorFlow and other libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "yG_n40gFzf9s" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow.experimental.numpy as tnp\n", + "\n", + "import numpy as np\n", + "import os\n", + "import time" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EHDoRoc5PKWz" + }, + "source": [ + "### Download the Shakespeare dataset\n", + "\n", + "Change the following line to run this code on your own data." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "pD_55cOxLkAb" + }, + "outputs": [], + "source": [ + "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UHjdCjDuSvX_" + }, + "source": [ + "### Read the data\n", + "\n", + "First, look in the text:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "aavnuByVymwK" + }, + "outputs": [], + "source": [ + "# Read, then decode for py2 compat.\n", + "text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n", + "# length of text is the number of characters in it\n", + "print ('Length of text: {} characters'.format(len(text)))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "Duhg9NrUymwO" + }, + "outputs": [], + "source": [ + "# Take a look at the first 250 characters in text\n", + "print(text[:250])" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "IlCgQBRVymwR" + }, + "outputs": [], + "source": [ + "# The unique characters in the file\n", + "vocab = sorted(set(text))\n", + "print ('{} unique characters'.format(len(vocab)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rNnrKn_lL-IJ" + }, + "source": [ + "## Process the text" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LFjSVAlWzf-N" + }, + "source": [ + "### Vectorize the text\n", + "\n", + "Before training, we need to map strings to a numerical representation. Create two lookup tables: one mapping characters to numbers, and another for numbers to characters." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "IalZLbvOzf-F" + }, + "outputs": [], + "source": [ + "# Creating a mapping from unique characters to indices\n", + "char2idx = {u:i for i, u in enumerate(vocab)}\n", + "idx2char = np.array(vocab)\n", + "\n", + "text_as_int = np.array([char2idx[c] for c in text])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bbmsf23Bymwe" + }, + "source": [ + "### The prediction task" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wssHQ1oGymwe" + }, + "source": [ + "Given a character, or a sequence of characters, what is the most probable next character? This is the task we're training the model to perform. The input to the model will be a sequence of characters, and we train the model to predict the output—the following character at each time step.\n", + "\n", + "Since RNNs maintain an internal state that depends on the previously seen elements, given all the characters computed until this moment, what is the next character?\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hgsVvVxnymwf" + }, + "source": [ + "### Create training examples and targets\n", + "\n", + "Next divide the text into example sequences. Each input sequence will contain `seq_length` characters from the text.\n", + "\n", + "For each input sequence, the corresponding targets contain the same length of text, except shifted one character to the right.\n", + "\n", + "So break the text into chunks of `seq_length+1`. For example, say `seq_length` is 4 and our text is \"Hello\". The input sequence would be \"Hell\", and the target sequence \"ello\"." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "0UHJDA39zf-O" + }, + "outputs": [], + "source": [ + "# The maximum length sentence we want for a single input in characters\n", + "seq_length = 100\n", + "examples_per_epoch = len(text)//(seq_length+1)\n", + "\n", + "# Create training examples / targets\n", + "char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)\n", + "\n", + "for i in char_dataset.take(5):\n", + " print(idx2char[i.numpy()])" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "l4hkDU3i7ozi" + }, + "outputs": [], + "source": [ + "sequences = char_dataset.batch(seq_length+1, drop_remainder=True)\n", + "\n", + "for item in sequences.take(5):\n", + " print(repr(''.join(idx2char[item.numpy()])))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "9NGu-FkO_kYU" + }, + "outputs": [], + "source": [ + "def split_input_target(chunk):\n", + " input_text = chunk[:-1]\n", + " target_text = chunk[1:]\n", + " return input_text, target_text\n", + "\n", + "dataset = sequences.map(split_input_target)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "GNbw-iR0ymwj" + }, + "outputs": [], + "source": [ + "for input_example, target_example in dataset.take(1):\n", + " print ('Input data: ', repr(''.join(idx2char[input_example.numpy()])))\n", + " print ('Target data:', repr(''.join(idx2char[target_example.numpy()])))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_33OHL3b84i0" + }, + "source": [ + "Each index of these vectors are processed as one time step. For the input at time step 0, the model receives the index for \"F\" and trys to predict the index for \"i\" as the next character. At the next timestep, it does the same thing but the `RNN` considers the previous step context in addition to the current input character." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "0eBu9WZG84i0" + }, + "outputs": [], + "source": [ + "for i, (input_idx, target_idx) in enumerate(zip(input_example[:5], target_example[:5])):\n", + " print(\"Step {:4d}\".format(i))\n", + " print(\" input: {} ({:s})\".format(input_idx, repr(idx2char[input_idx])))\n", + " print(\" expected output: {} ({:s})\".format(target_idx, repr(idx2char[target_idx])))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MJdfPmdqzf-R" + }, + "source": [ + "### Create training batches\n", + "\n", + "We used `tf.data` to split the text into manageable sequences. But before feeding this data into the model, we need to shuffle the data and pack it into batches." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "p2pGotuNzf-S" + }, + "outputs": [], + "source": [ + "# Batch size\n", + "BATCH_SIZE = 64\n", + "\n", + "# Buffer size to shuffle the dataset\n", + "# (TF data is designed to work with possibly infinite sequences,\n", + "# so it doesn't attempt to shuffle the entire sequence in memory. Instead,\n", + "# it maintains a buffer in which it shuffles elements).\n", + "BUFFER_SIZE = 10000\n", + "\n", + "dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)\n", + "\n", + "dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r6oUuElIMgVx" + }, + "source": [ + "## Build The Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oKesgVJRY9g4" + }, + "source": [ + "We manually implement the model from scratch, using `tf.numpy` and some low-level TF ops. A `Model` object has three layers: `Embedding`, `GRU` and `Dense`. `Embedding` and `Dense` are little more than just wrappers around `tnp.take` and `tnp.dot`, but we can use them to familiarize ourself with the structure of a layer. Each layer has two essential methods: `build` and `__call__`. `build` creates and initializes the layer's weights and state, which are things that change during the training process. `__call__` is the forward function that calculates outputs given inputs, using the layer's weights and state internally." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9dm_WoL29UmO" + }, + "source": [ + "Our model (more precisely the `GRU` layer) is stateful, because each call of `__call__` will change its internal state, affecting the next call. " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "zHT8cLh7EAsg" + }, + "outputs": [], + "source": [ + "# Length of the vocabulary in chars\n", + "vocab_size = len(vocab)\n", + "\n", + "# The embedding dimension\n", + "embedding_dim = 256\n", + "\n", + "# Number of RNN units\n", + "rnn_units = 1024" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "dGrbGm-oGnqB" + }, + "outputs": [], + "source": [ + "class Embedding:\n", + "\n", + " def __init__(self, vocab_size, embedding_dim):\n", + " self._vocab_size = vocab_size\n", + " self._embedding_dim = embedding_dim\n", + " self._built = False\n", + "\n", + " def __call__(self, inputs):\n", + " if not self._built:\n", + " self.build(inputs)\n", + " return tnp.take(self.weights, inputs, axis=0)\n", + "\n", + " def build(self, inputs):\n", + " del inputs\n", + " self.weights = tf.Variable(tnp.random.randn(\n", + " self._vocab_size, self._embedding_dim).astype(np.float32))\n", + " self._built = True\n", + "\n", + "\n", + "class GRUCell:\n", + " \"\"\"Builds a traditional GRU cell with dense internal transformations.\n", + "\n", + " Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555\n", + " \"\"\"\n", + "\n", + " def __init__(self, n_units, forget_bias=0.0):\n", + " self._n_units = n_units\n", + " self._forget_bias = forget_bias\n", + " self._built = False\n", + "\n", + " def __call__(self, inputs):\n", + " if not self._built:\n", + " self.build(inputs)\n", + " x, gru_state = inputs\n", + " # Dense layer on the concatenation of x and h.\n", + " y = tnp.dot(tnp.concatenate([x, gru_state], axis=-1), self.w1) + self.b1\n", + " # Update and reset gates.\n", + " u, r = tnp.split(tf.sigmoid(y), 2, axis=-1)\n", + " # Candidate.\n", + " c = tnp.dot(tnp.concatenate([x, r * gru_state], axis=-1), self.w2) + self.b2\n", + " new_gru_state = u * gru_state + (1 - u) * tnp.tanh(c)\n", + " return new_gru_state\n", + "\n", + " def build(self, inputs):\n", + " # State last dimension must be n_units.\n", + " assert inputs[1].shape[-1] == self._n_units\n", + " # The dense layer input is the input and half of the GRU state.\n", + " dense_shape = inputs[0].shape[-1] + self._n_units\n", + " self.w1 = tf.Variable(tnp.random.uniform(\n", + " -0.01, 0.01, (dense_shape, 2 * self._n_units)).astype(tnp.float32))\n", + " self.b1 = tf.Variable((tnp.random.randn(2 * self._n_units) * 1e-6 + self._forget_bias\n", + " ).astype(tnp.float32))\n", + " self.w2 = tf.Variable(tnp.random.uniform(\n", + " -0.01, 0.01, (dense_shape, self._n_units)).astype(tnp.float32))\n", + " self.b2 = tf.Variable((tnp.random.randn(self._n_units) * 1e-6).astype(tnp.float32))\n", + " self._built = True\n", + "\n", + " @property\n", + " def weights(self):\n", + " return (self.w1, self.b1, self.w2, self.b2)\n", + "\n", + "\n", + "class GRU:\n", + "\n", + " def __init__(self, n_units, forget_bias=0.0, stateful=False):\n", + " self._cell = GRUCell(n_units, forget_bias)\n", + " self._stateful = stateful\n", + " self._built = False\n", + "\n", + " def __call__(self, inputs):\n", + " if not self._built:\n", + " self.build(inputs)\n", + " if self._stateful:\n", + " state = self.state.read_value()\n", + " else:\n", + " state = self._init_state(inputs.shape[0]) \n", + " inputs = tnp.transpose(inputs, (1, 0, 2))\n", + " output = tf.scan(\n", + " lambda gru_state, x: self._cell((x, gru_state)),\n", + " inputs, state)\n", + " if self._stateful:\n", + " self.state.assign(output[-1, ...])\n", + " return tnp.transpose(output, [1, 0, 2])\n", + "\n", + " def _init_state(self, batch_size):\n", + " return tnp.zeros([batch_size, self._cell._n_units], tnp.float32)\n", + "\n", + " def reset_state(self):\n", + " if not self._stateful:\n", + " return\n", + " self.state.assign(tf.zeros_like(self.state))\n", + "\n", + " def create_state(self, batch_size):\n", + " self.state = tf.Variable(self._init_state(batch_size))\n", + "\n", + " def build(self, inputs):\n", + " s = inputs.shape[0:1] + inputs.shape[2:]\n", + " shapes = (s, s[:-1] + (self._cell._n_units,)) \n", + " self._cell.build([tf.TensorSpec(x, tf.float32) for x in shapes])\n", + " if self._stateful:\n", + " self.create_state(inputs.shape[0])\n", + " else:\n", + " self.state = ()\n", + " self._built = True\n", + " \n", + " @property\n", + " def weights(self):\n", + " return self._cell.weights\n", + "\n", + "\n", + "class Dense:\n", + "\n", + " def __init__(self, n_units, activation=None):\n", + " self._n_units = n_units\n", + " self._activation = activation\n", + " self._built = False\n", + "\n", + " def __call__(self, inputs):\n", + " if not self._built:\n", + " self.build(inputs)\n", + " y = tnp.dot(inputs, self.w) +self.b\n", + " if self._activation != None:\n", + " y = self._activation(y)\n", + " return y\n", + "\n", + " def build(self, inputs):\n", + " shape_w = (inputs.shape[-1], self._n_units)\n", + " lim = tnp.sqrt(6.0 / (shape_w[0] + shape_w[1]))\n", + " self.w = tf.Variable(tnp.random.uniform(-lim, lim, shape_w).astype(tnp.float32))\n", + " self.b = tf.Variable((tnp.random.randn(self._n_units) * 1e-6).astype(tnp.float32))\n", + " self._built = True\n", + "\n", + " @property\n", + " def weights(self):\n", + " return (self.w, self.b)\n", + "\n", + "\n", + "class Model:\n", + "\n", + " def __init__(self, vocab_size, embedding_dim, rnn_units, forget_bias=0.0, stateful=False, activation=None):\n", + " self._embedding = Embedding(vocab_size, embedding_dim)\n", + " self._gru = GRU(rnn_units, forget_bias=forget_bias, stateful=stateful)\n", + " self._dense = Dense(vocab_size, activation=activation)\n", + " self._layers = [self._embedding, self._gru, self._dense]\n", + " self._built = False\n", + "\n", + " def __call__(self, inputs):\n", + " if not self._built:\n", + " self.build(inputs)\n", + " xs = inputs\n", + " for layer in self._layers:\n", + " xs = layer(xs)\n", + " return xs\n", + " \n", + " def build(self, inputs):\n", + " self._embedding.build(inputs)\n", + " self._gru.build(tf.TensorSpec(inputs.shape + (self._embedding._embedding_dim,), tf.float32))\n", + " self._dense.build(tf.TensorSpec(inputs.shape + (self._gru._cell._n_units,), tf.float32))\n", + " self._built = True\n", + "\n", + " @property\n", + " def weights(self):\n", + " return [layer.weights for layer in self._layers]\n", + "\n", + " @property\n", + " def state(self):\n", + " return self._gru.state\n", + "\n", + " def create_state(self, *args):\n", + " self._gru.create_state(*args)\n", + "\n", + " def reset_state(self, *args):\n", + " self._gru.reset_state(*args)\n", + "\n", + "\n", + "model = Model(\n", + " vocab_size = vocab_size,\n", + " embedding_dim=embedding_dim,\n", + " rnn_units=rnn_units,\n", + " stateful=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RkA5upJIJ7W7" + }, + "source": [ + "For each character the model looks up the embedding, runs the GRU one timestep with the embedding as input, and applies the dense layer to generate logits predicting the log-likelihood of the next character." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-ubPo0_9Prjb" + }, + "source": [ + "## Try the model\n", + "\n", + "Now run the model to see that it behaves as expected.\n", + "\n", + "First check the shape of the output:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "lzuvs0a4IR6m" + }, + "outputs": [], + "source": [ + " for input_example_batch, target_example_batch in dataset.take(1):\n", + " input_example_batch = tnp.asarray(input_example_batch)\n", + " example_batch_predictions = model(input_example_batch)\n", + " print(example_batch_predictions.shape, \"# (batch_size, sequence_length, vocab_size)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q6NzLBi4VM4o" + }, + "source": [ + "In the above example the sequence length of the input is `100` but the model can be run on inputs of any length:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uwv0gEkURfx1" + }, + "source": [ + "To get actual predictions from the model we need to sample from the output distribution, to get actual character indices. This distribution is defined by the logits over the character vocabulary.\n", + "\n", + "Note: It is important to _sample_ from this distribution as taking the _argmax_ of the distribution can easily get the model stuck in a loop.\n", + "\n", + "Try it for the first example in the batch:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "RP56TbSEgcNp" + }, + "outputs": [], + "source": [ + "example_batch_predictions[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "4V4MfFg0RQJg" + }, + "outputs": [], + "source": [ + "sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)\n", + "sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QM1Vbxs_URw5" + }, + "source": [ + "This gives us, at each timestep, a prediction of the next character index:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "YqFMUQc_UFgM" + }, + "outputs": [], + "source": [ + "sampled_indices" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LfLtsP3mUhCG" + }, + "source": [ + "Decode these to see the text predicted by this untrained model:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "xWcFwPwLSo05" + }, + "outputs": [], + "source": [ + "print(\"Input: \\n\", repr(\"\".join(idx2char[input_example_batch[0]])))\n", + "print()\n", + "print(\"Next Char Predictions: \\n\", repr(\"\".join(idx2char[sampled_indices ])))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LJL0Q0YPY6Ee" + }, + "source": [ + "## Train the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YCbHQHiaa4Ic" + }, + "source": [ + "At this point the problem can be treated as a standard classification problem. Given the previous RNN state, and the input this time step, predict the class of the next character." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "trpqTWyvk0nr" + }, + "source": [ + "### Loss function" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mJRGdSfi-2D8" + }, + "source": [ + "We define the loss function from scratch, using `tf.nn.log_softmax`. (Our definition is the same as `tf.keras.losses.sparse_categorical_crossentropy`.)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "Dhv7DC6TZ-2i" + }, + "outputs": [], + "source": [ + "def one_hot(labels, n):\n", + " return (labels[..., np.newaxis] == tnp.arange(n)).astype(np.float32)\n", + "\n", + "def loss_fn(labels, predictions):\n", + " predictions = tf.nn.log_softmax(predictions)\n", + " return -tnp.sum(predictions * one_hot(tnp.asarray(labels), predictions.shape[-1]), axis=-1)\n", + "\n", + "example_batch_loss = loss_fn(target_example_batch, example_batch_predictions)\n", + "print(\"Prediction shape: \", example_batch_predictions.shape, \" # (batch_size, sequence_length, vocab_size)\")\n", + "print(\"scalar_loss: \", tnp.mean(example_batch_loss))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mHQWnJCY_fBu" + }, + "source": [ + "### Optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4jHj8s57_NCk" + }, + "source": [ + "Keeping the DIY spirit, we implement the Adam optimizer from scratch." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "NJDx4_SN5Vse" + }, + "outputs": [], + "source": [ + "class Adam:\n", + "\n", + " def __init__(self, learning_rate=0.001, b1=0.9, b2=0.999, eps=1e-7):\n", + " self._lr = learning_rate\n", + " self._b1 = b1\n", + " self._b2 = b2\n", + " self._eps = eps\n", + " self._built = False\n", + "\n", + " def build(self, weights):\n", + " self._m = tf.nest.map_structure(lambda x: tf.Variable(tnp.zeros_like(x)), weights)\n", + " self._v = tf.nest.map_structure(lambda x: tf.Variable(tnp.zeros_like(x)), weights)\n", + " self._step = tf.Variable(tnp.asarray(0, np.int64))\n", + " self._built = True\n", + "\n", + " def _update(self, weights_var, grads, m_var, v_var):\n", + " b1 = self._b1\n", + " b2 = self._b2\n", + " eps = self._eps\n", + " step = tnp.asarray(self._step, np.float32)\n", + " lr = self._lr\n", + " weights = tnp.asarray(weights_var)\n", + " m = tnp.asarray(m_var)\n", + " v = tnp.asarray(v_var)\n", + " m = (1 - b1) * grads + b1 * m # First moment estimate.\n", + " v = (1 - b2) * (grads ** 2) + b2 * v # Second moment estimate.\n", + " mhat = m / (1 - b1 ** (step + 1)) # Bias correction.\n", + " vhat = v / (1 - b2 ** (step + 1)) \n", + " weights_var.assign_sub((lr * mhat / (tnp.sqrt(vhat) + eps)).astype(weights.dtype))\n", + " m_var.assign(m)\n", + " v_var.assign(v)\n", + "\n", + " def apply_gradients(self, weights, grads):\n", + " if not self._built:\n", + " self.build(weights)\n", + " tf.nest.map_structure(lambda *args: self._update(*args), weights, grads, self._m, self._v)\n", + " self._step.assign_add(1)\n", + "\n", + " @property\n", + " def state(self):\n", + " return (self._step, self._m, self._v)\n", + "\n", + "\n", + "optimizer = Adam()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3Ky3F_BhgkTW" + }, + "source": [ + "### Training loop" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EdhARuXFACy0" + }, + "source": [ + "Again, we write our training loop from scratch." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IxdOA-rgyGvs" + }, + "source": [ + "To keep training time reasonable, use 10 epochs to train the model. In Colab, set the runtime to GPU for faster training." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "Q4nN6i0oirh2" + }, + "outputs": [], + "source": [ + "@tf.function\n", + "def train_step(inp, target):\n", + " with tf.GradientTape() as tape:\n", + " # tape.watch(tf.nest.flatten(weights))\n", + " predictions = model(inp)\n", + " loss = tnp.mean(loss_fn(target, predictions))\n", + " weights = model.weights\n", + " grads = tape.gradient(loss, weights)\n", + " optimizer.apply_gradients(weights, grads)\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "RIq1BwUD5mRQ" + }, + "outputs": [], + "source": [ + "# Training step\n", + "EPOCHS = 10\n", + "\n", + "model.create_state(BATCH_SIZE)\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + "\n", + " # initializing the hidden state at the start of every epoch\n", + " model.reset_state()\n", + "\n", + " for (batch_n, (inp, target)) in enumerate(dataset):\n", + " loss = train_step(inp, target)\n", + "\n", + " if batch_n % 100 == 0:\n", + " template = 'Epoch {} Batch {} Loss {}'\n", + " print(template.format(epoch+1, batch_n, loss))\n", + "\n", + " print ('Epoch {} Loss {}'.format(epoch+1, loss))\n", + " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kKkD5M6eoSiN" + }, + "source": [ + "## Generate text" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DjGz1tDkzf-u" + }, + "source": [ + "The following code block generates the text:\n", + "\n", + "* It Starts by choosing a start string, initializing the RNN state and setting the number of characters to generate.\n", + "\n", + "* Get the prediction distribution of the next character using the start string and the RNN state.\n", + "\n", + "* Then, use a categorical distribution to calculate the index of the predicted character. Use this predicted character as our next input to the model.\n", + "\n", + "* The RNN state returned by the model is fed back into the model so that it now has more context, instead than only one character. After predicting the next character, the modified RNN states are again fed back into the model, which is how it learns as it gets more context from the previously predicted characters.\n", + "\n", + "Looking at the generated text, you'll see the model knows when to capitalize, make paragraphs and imitates a Shakespeare-like writing vocabulary. With the small number of training epochs, it has not yet learned to form coherent sentences." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LyeYRiuVjodY" + }, + "source": [ + "To keep this prediction step simple, use a batch size of 1." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "WvuwZBX5Ogfd" + }, + "outputs": [], + "source": [ + "def generate_text(model, start_string):\n", + " # Evaluation step (generating text using the learned model)\n", + "\n", + " # Number of characters to generate\n", + " num_generate = 1000\n", + "\n", + " # Converting our start string to numbers (vectorizing)\n", + " input_eval = [char2idx[s] for s in start_string]\n", + " input_eval = tf.expand_dims(input_eval, 0)\n", + "\n", + " # Empty string to store our results\n", + " text_generated = []\n", + "\n", + " # Low temperatures results in more predictable text.\n", + " # Higher temperatures results in more surprising text.\n", + " # Experiment to find the best setting.\n", + " temperature = 1.0\n", + "\n", + " # Here batch size == 1\n", + " model.create_state(1)\n", + " for i in range(num_generate):\n", + " predictions = model(input_eval)\n", + " # remove the batch dimension\n", + " predictions = tf.squeeze(predictions, 0)\n", + "\n", + " # using a categorical distribution to predict the character returned by the model\n", + " predictions = predictions / temperature\n", + " predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()\n", + "\n", + " # We pass the predicted character as the next input to the model\n", + " # along with the previous hidden state\n", + " input_eval = tf.expand_dims([predicted_id], 0)\n", + "\n", + " text_generated.append(idx2char[predicted_id])\n", + "\n", + " return (start_string + ''.join(text_generated))\n", + "\n", + "print(generate_text(model, start_string=u\"ROMEO: \"))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AM2Uma_-yVIq" + }, + "source": [ + "The easiest thing you can do to improve the results it to train it for longer (try `EPOCHS=30`).\n", + "\n", + "You can also experiment with a different start string, or try adding another RNN layer to improve the model's accuracy, or adjusting the temperature parameter to generate more or less random predictions." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "TensorFlow_NumPy_Text_Generation.ipynb", + "private_outputs": true, + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py index 369297bd85a..64542656273 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py @@ -822,16 +822,32 @@ def transpose(a, axes=None): @np_utils.np_doc('swapaxes') def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring a = asarray(a).data + def adjust_axes(axes, rank): + def f(x): + if isinstance(x, int): + if x < 0: + x = x + rank + else: + x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x) + return x + return nest.map_structure(f, axes) - a_rank = array_ops.rank(a) - axis1 = array_ops.where_v2(axis1 < 0, axis1 + a_rank, axis1) - axis2 = array_ops.where_v2(axis2 < 0, axis2 + a_rank, axis2) - - perm = math_ops.range(a_rank) - perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]], - [axis2, axis1]) + if (a.shape.rank is not None and + isinstance(axis1, int) and isinstance(axis2, int)): + # This branch makes sure `perm` is statically known, to avoid a + # not-compile-time-constant XLA error. + a_rank = a.shape.rank + axis1, axis2 = adjust_axes((axis1, axis2), a_rank) + perm = list(range(a_rank)) + perm[axis1] = axis2 + perm[axis2] = axis1 + else: + a_rank = array_ops.rank(a) + axis1, axis2 = adjust_axes((axis1, axis2), a_rank) + perm = math_ops.range(a_rank) + perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]], + [axis2, axis1]) a = array_ops.transpose(a, perm) - return np_utils.tensor_to_ndarray(a) diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD index 119a944e867..9208073d946 100644 --- a/tensorflow/python/ops/parallel_for/BUILD +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -111,7 +111,10 @@ py_library( cuda_py_test( name = "control_flow_ops_test", srcs = ["control_flow_ops_test.py"], - tags = ["no_rocm"], + shard_count = 16, + tags = [ + "no_rocm", + ], deps = [ ":control_flow_ops", ":test_util", @@ -154,6 +157,9 @@ cuda_py_test( cuda_py_test( name = "array_test", srcs = ["array_test.py"], + tags = [ + "notsan", # TODO(b/170999669): Data race + ], deps = [ ":control_flow_ops", ":test_util", diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index 3a0c6cf1a14..f641687e990 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -174,6 +174,7 @@ class PForTest(PForTestCase): pfor_control_flow_ops.vectorized_map( lambda x: x * x, math_ops.range(4))) self.assertTrue(def_function.functions_run_eagerly()) + def_function.run_functions_eagerly(False) @test_util.run_all_in_graph_and_eager_modes @@ -970,6 +971,65 @@ class TensorListTest(PForTestCase): self._test_loop_fn(loop_fn, 2) + def test_create_outside_and_push_back(self): + h = list_ops.tensor_list_reserve([2], 2, dtypes.int32) + + def loop_fn(i): + handle = list_ops.tensor_list_push_back(h, [i, 2]) + handle = list_ops.tensor_list_push_back(handle, [1, 2]) + handle = list_ops.tensor_list_push_back(handle, [1, 2]) + return list_ops.tensor_list_stack(handle, dtypes.int32) + + self._test_loop_fn(loop_fn, 3) + + def test_create_inside_and_push_back(self): + + def loop_fn(i): + handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) + handle = list_ops.tensor_list_push_back(handle, [i, 2]) + handle = list_ops.tensor_list_push_back(handle, [1, 2]) + return list_ops.tensor_list_stack(handle, dtypes.int32) + + self._test_loop_fn(loop_fn, 3) + + def test_pop_back_no_shape(self): + + def loop_fn(i): + handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) + handle = list_ops.tensor_list_push_back(handle, [1, 2]) + handle = list_ops.tensor_list_push_back(handle, [i, 2]) + handle, tensor = list_ops.tensor_list_pop_back(handle, dtypes.int32) + return tensor, list_ops.tensor_list_stack(handle, dtypes.int32) + + self._test_loop_fn(loop_fn, 3) + + def test_pop_back_no_shape_capture(self): + h = list_ops.tensor_list_reserve([2], 1, dtypes.int32) + h = list_ops.tensor_list_push_back(h, [1, 2]) + + def loop_fn(i): + handle, tensor = list_ops.tensor_list_pop_back(h, dtypes.int32) + handle = list_ops.tensor_list_push_back(handle, [1, i]) + return tensor, list_ops.tensor_list_stack(handle, dtypes.int32) + + self._test_loop_fn(loop_fn, 3) + + def test_pop_back_with_shape(self): + + @def_function.function + def loop_fn(i): + with backprop.GradientTape() as tape: + handle = list_ops.tensor_list_reserve(None, 1, dtypes.float32) + x = math_ops.cast(i, dtypes.float32)[None] + tape.watch(x) + handle = list_ops.tensor_list_push_back(handle, x) + stacked = list_ops.tensor_list_stack(handle, dtypes.float32) + list_grad = tape.gradient(stacked, x, x) + self.assertEqual("TensorListPopBack", list_grad.op.type) + return list_grad, stacked, list_grad.op.inputs[1] + + self._test_loop_fn(loop_fn, 3) + def test_create_outside_and_scatter(self): h = list_ops.tensor_list_reserve([2], 2, dtypes.int32) diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 7e460176c61..21419e745ad 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -3748,6 +3748,42 @@ def _convert_tensor_array_set_item(pfor_input): return wrap(_tile_variant(handle, pfor_input), True) +@RegisterPFor("TensorListPushBack") +def _convert_tensor_list_push_back(pfor_input): + handle, handle_stacked, _ = pfor_input.input(0) + tensor, tensor_stacked, _ = pfor_input.input(1) + if handle_stacked: + handle = _untile_variant(handle) + else: + handle = _stack_tensor_list(handle, tensor.dtype, + pfor_input.pfor.loop_len_vector) + if not tensor_stacked: + tensor = _stack(tensor, pfor_input.pfor.loop_len_vector).t + handle = list_ops.tensor_list_push_back(handle, tensor) + return wrap(_tile_variant(handle, pfor_input), True) + + +@RegisterPFor("TensorListPopBack") +def _convert_tensor_array_push_back(pfor_input): + handle = pfor_input.stacked_input(0) + element_shape = pfor_input.unstacked_input(1) + handle = _untile_variant(handle) + + if element_shape.shape.ndims == 0: + # Default / unspecified + vectorized_shape = -1 + else: + # PopBack has an element shape set when it's the gradient of PushBack, only + # used when the list is uninitialized. + vectorized_shape = array_ops.concat( + [pfor_input.pfor.loop_len_vector, element_shape], axis=0) + + output_handle, tensor = gen_list_ops.tensor_list_pop_back( + input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"), + element_shape=vectorized_shape) + return wrap(output_handle, True), wrap(tensor, True) + + @RegisterPFor("TensorListConcatV2") def _convert_tensor_list_concat_v2(pfor_input): input_handle = pfor_input.stacked_input(0) @@ -4134,7 +4170,7 @@ def _outputs_for_branch(func_name, indices, pfor_input, inputs): stacked_outputs = [] for out in outputs: if not out.is_stacked: - stacked_outputs.append(_stack(out.t, array_ops.size(indices)).t) + stacked_outputs.append(_stack(out.t, [array_ops.size(indices)]).t) else: stacked_outputs.append(out.t) return stacked_outputs diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index 309957a76a1..2934491e69a 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -510,6 +510,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//tensorflow/python:tensor_array_grad", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", "//tensorflow/python/data/ops:dataset_ops", @@ -641,6 +642,9 @@ py_test( python_version = "PY3", shard_count = 4, srcs_version = "PY2AND3", + tags = [ + "notsan", # TODO(b/170902201): Flaky + ], deps = [ ":ragged_factory_ops", ":ragged_gather_ops", diff --git a/tensorflow/python/ops/ragged/ragged_conversion_ops.py b/tensorflow/python/ops/ragged/ragged_conversion_ops.py index e8c625ccc73..e915d1ecd61 100644 --- a/tensorflow/python/ops/ragged/ragged_conversion_ops.py +++ b/tensorflow/python/ops/ragged/ragged_conversion_ops.py @@ -143,3 +143,42 @@ def to_sparse(rt_input, name=None): def from_sparse(st_input, name=None): return ragged_tensor.RaggedTensor.from_sparse(st_input, name) + + +@ops.RegisterGradient("RaggedTensorFromVariant") +def _ragged_tensor_from_variant_grad(op, *grads): + """Gradient for RaggedTensorFromVariant op.""" + + variant_rank = op.inputs[0].shape.rank + if variant_rank == 0: + batched_input = False + elif variant_rank == 1: + batched_input = True + elif variant_rank is None: + batched_input = (op.get_attr("output_ragged_rank") > 0) + else: + # TODO(edloper): Add a batch_dims argument to RaggedTensorToVariant, so + # we can support this. + raise ValueError("Unable to compute gradient: RaggedTensorToVariant " + "can currently only generate 0D or 1D output.") + return [ + gen_ragged_conversion_ops.ragged_tensor_to_variant( + rt_nested_splits=op.outputs[:-1], + rt_dense_values=grads[-1], + batched_input=batched_input) + ] + + +@ops.RegisterGradient("RaggedTensorToVariant") +def _ragged_tensor_to_variant_grad(op, encoded_ragged_grad): + """Gradient for RaggedTensorToVariant op.""" + dense_values = op.inputs[-1] + ragged_rank = len(op.inputs) - 1 + row_splits = 0 if ragged_rank == 0 else op.inputs[0] + values_grad = gen_ragged_conversion_ops.ragged_tensor_to_variant_gradient( + encoded_ragged_grad=encoded_ragged_grad, + row_splits=row_splits, + dense_values_shape=array_ops.shape(dense_values), + Tvalues=op.inputs[-1].dtype) + result = [None] * ragged_rank + [values_grad] + return result diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index 5f713fa0793..0f9443cabb4 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -44,6 +44,7 @@ from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged.row_partition import RowPartition from tensorflow.python.types import internal as internal_types +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export from tensorflow.tools.docs import doc_controls @@ -341,6 +342,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, row_partition=row_partition) @classmethod + @dispatch.add_dispatch_support def from_value_rowids(cls, values, value_rowids, @@ -399,6 +401,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_row_splits(cls, values, row_splits, name=None, validate=True): """Creates a `RaggedTensor` with rows partitioned by `row_splits`. @@ -445,6 +448,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_row_lengths(cls, values, row_lengths, name=None, validate=True): """Creates a `RaggedTensor` with rows partitioned by `row_lengths`. @@ -487,6 +491,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_row_starts(cls, values, row_starts, name=None, validate=True): """Creates a `RaggedTensor` with rows partitioned by `row_starts`. @@ -526,6 +531,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_row_limits(cls, values, row_limits, name=None, validate=True): """Creates a `RaggedTensor` with rows partitioned by `row_limits`. @@ -562,6 +568,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_uniform_row_length(cls, values, uniform_row_length, @@ -636,6 +643,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_nested_value_rowids(cls, flat_values, nested_value_rowids, @@ -692,6 +700,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return result @classmethod + @dispatch.add_dispatch_support def from_nested_row_splits(cls, flat_values, nested_row_splits, @@ -731,6 +740,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return result @classmethod + @dispatch.add_dispatch_support def from_nested_row_lengths(cls, flat_values, nested_row_lengths, @@ -1307,6 +1317,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, A `RaggedTensor`. `result.rank = 1 + new_values.rank`. `result.ragged_rank = 1 + new_values.ragged_rank` """ + new_values = _convert_to_ragged_tensor_values(new_values) new_values.shape.with_rank_at_least(1) self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1]) if (isinstance(new_values, RaggedTensor) and @@ -1339,8 +1350,8 @@ class RaggedTensor(composite_tensor.CompositeTensor, if isinstance(self._values, RaggedTensor): return self.with_values(self.values.with_flat_values(new_values)) else: - _assert_is_supported_ragged_values_type(new_values) - return self.with_values(new_values) + new_values = _convert_to_ragged_tensor_values(new_values) + return self.with_values(new_values) def with_row_splits_dtype(self, dtype): """Returns a copy of this RaggedTensor with the given `row_splits` dtype. @@ -1479,6 +1490,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, #============================================================================= @classmethod + @dispatch.add_dispatch_support def from_tensor(cls, tensor, lengths=None, @@ -1725,6 +1737,11 @@ class RaggedTensor(composite_tensor.CompositeTensor, if default_value is None: default_value = array_ops.zeros((), self.dtype) + if (isinstance(shape, (list, tuple)) and + any(isinstance(v, ops.Tensor) for v in shape) and + all(isinstance(v, (int, ops.Tensor)) for v in shape)): + shape = array_ops.stack(shape) + shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype) tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor( shape=shape_tensor, @@ -1751,6 +1768,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return tensor @classmethod + @dispatch.add_dispatch_support def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64): """Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`. @@ -2521,8 +2539,8 @@ def convert_to_tensor_or_ragged_tensor(value, return RaggedTensor.from_nested_row_splits( flat_values, value.nested_row_splits, validate=False) else: - return ops.convert_to_tensor( - value=value, dtype=dtype, preferred_dtype=preferred_dtype, name=name) + return ops.convert_to_tensor_v2_with_dispatch( + value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name) def _convert_to_ragged_tensor_values(value): @@ -2863,9 +2881,6 @@ def _get_optional_partition_dtype(values): return None -ops.no_gradient("RaggedTensorToVariant") - - _SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor) diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py index d92cb9cec6c..a38c5527305 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor_test.py +++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools from absl.testing import parameterized import numpy as np from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -30,8 +32,15 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_ragged_conversion_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import map_fn +from tensorflow.python.ops import math_grad # pylint: disable=unused-import +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_math_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -1233,19 +1242,21 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(factory(**kwargs)) + #============================================================================= + # RaggedTensor Variant conversion + #============================================================================= -#============================================================================= -# RaggedTensor Variant conversion -#============================================================================= - - @parameterized.parameters( + @parameterized.named_parameters( { + 'testcase_name': 'Shape_5_none', 'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]], 'ragged_rank': 1 }, { + 'testcase_name': 'Shape_4_none_2', 'ragged_constant': [[[1, 2]], [], [[3, 4]], []], 'ragged_rank': 1 }, { + 'testcase_name': 'Shape_1_none_none', 'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]], 'ragged_rank': 2 }) @@ -1432,6 +1443,131 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase): output_ragged_rank=1, input_ragged_rank=1) + def _testRaggedVarientGradient(self, func, x, expected_grad): + x = constant_op.constant(x) + if context.executing_eagerly(): + with backprop.GradientTape() as t: + t.watch(x) + y = func(x) + g = t.gradient(y, x) + else: + y = func(x) + g = gradients_impl.gradients(ys=y, xs=x)[0] + self.assertAllClose(g, expected_grad) + + def testRaggedVariantGradients(self): + def func(x): + rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) + rt2 = rt1 * [[10], [100], [1000]] + v = rt2._to_variant(batched_input=False) + rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) + return rt3.flat_values + + self._testRaggedVarientGradient( + func, + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [10., 10., 10., 10., 100., 100., 100., 1000.]) + + def testRaggedVariantGradientsBatched(self): + def func(x): + rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) + rt2 = rt1 * [[10], [100], [1000]] + v = rt2._to_variant(batched_input=True) + rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) + return rt3.flat_values + + self._testRaggedVarientGradient( + func, + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [10., 10., 10., 10., 100., 100., 100., 1000.]) + + def testRaggedVariantGradientsBatchedAndSliced(self): + def func(x, i): + rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) + rt2 = rt1 * [[10], [100], [1000]] + v_slice = rt2._to_variant(batched_input=True)[i] + return RaggedTensor._from_variant(v_slice, dtype=rt2.dtype, + output_ragged_rank=0) + + self._testRaggedVarientGradient( + functools.partial(func, i=0), + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [10., 10., 10., 10., 0., 0., 0., 0.]) + self._testRaggedVarientGradient( + functools.partial(func, i=1), + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [0., 0., 0., 0., 100., 100., 100., 0.]) + self._testRaggedVarientGradient( + functools.partial(func, i=2), + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [0., 0., 0., 0., 0., 0., 0., 1000.]) + + def testRaggedVariantGradientsRaggedRank0(self): + def func(x): + x2 = x * 2 + v = gen_ragged_conversion_ops.ragged_tensor_to_variant( + [], x2, batched_input=False) + return RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=0) + + self._testRaggedVarientGradient( + func, + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]) + + def testRaggedVariantGradientsRaggedRank3(self): + def func(x): + x2 = x * 2 + rt1 = RaggedTensor.from_nested_row_splits( + x2, ([0, 0, 3], [0, 2, 2, 3], [0, 4, 7, 8])) + v = rt1._to_variant(batched_input=False) + rt3 = RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=3) + return rt3.flat_values + + self._testRaggedVarientGradient( + func, + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]) + + def testRaggedVariantGradientsViaMapFn(self): + rt = RaggedTensor.from_row_splits( + values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 7, 8]) + + def func(x): + + def transform_row(row): + return math_ops.sqrt( + math_ops.reduce_mean(math_ops.square(row * x), keepdims=True)) + + return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt)) + + self._testRaggedVarientGradient(func, 3.0, 14.653377) + + def testRaggedVariantGradientsViaMapFnReduce(self): + def func(x): + rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) + return map_fn.map_fn( + math_ops.reduce_max, rt1, + fn_output_signature=tensor_spec.TensorSpec((), x.dtype)) + + self._testRaggedVarientGradient( + func, + [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]) + + def testRaggedVariantGradientsErrors(self): + if context.executing_eagerly(): + return + + rt = RaggedTensor.from_row_splits([1.0, 2.0], row_splits=[0, 2, 2]) + v1 = rt._to_variant() + v2 = array_ops.stack([array_ops.stack([v1])]) + y = RaggedTensor._from_variant(v2, rt.dtype, output_ragged_rank=3) + + with self.assertRaisesRegex( + ValueError, 'Unable to compute gradient: RaggedTensorToVariant ' + 'can currently only generate 0D or 1D output.'): + gradients_impl.gradients(ys=y.flat_values, xs=rt.flat_values) + def assertNumpyObjectTensorsRecursivelyEqual(self, a, b, msg): """Check that two numpy arrays are equal. diff --git a/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py b/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py index 28955c825e6..e01c3442cd9 100644 --- a/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py @@ -720,6 +720,11 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase, return array_ops.placeholder_with_default(arg, [None] * arg.shape.rank) raise AssertionError('Unexpected shape_info %r' % shape_info) + def test_shape_is_list_including_tensor_element(self): + rt = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]]) + result = rt.to_tensor(shape=[2, constant_op.constant(2)]) + self.assertAllEqual(result, [[1, 2], [4, 0]]) + class RaggedToDenseBenchmark(googletest.Benchmark): diff --git a/tensorflow/python/ops/structured/structured_tensor.py b/tensorflow/python/ops/structured/structured_tensor.py index c09a38f1d21..5b50cf42c56 100644 --- a/tensorflow/python/ops/structured/structured_tensor.py +++ b/tensorflow/python/ops/structured/structured_tensor.py @@ -1028,6 +1028,57 @@ class StructuredTensorSpec(type_spec.BatchableTypeSpec): self._shape[1:], dict((k, v._unbatch()) for (k, v) in self._field_specs.items())) + @property + def _flat_tensor_specs(self): + # pylint: disable=protected-access + result = [] + for _, field_spec in sorted(self._field_specs.items(), key=lambda t: t[0]): + result.extend(field_spec._flat_tensor_specs) + return result + + def _to_tensor_list(self, value): + return self._to_tensor_list_internal(value, batched=False) + + def _to_batched_tensor_list(self, value): + return self._to_tensor_list_internal(value, batched=True) + + def _from_compatible_tensor_list(self, tensor_list): + # pylint: disable=protected-access + fields = {} + pos = 0 + for field_name, field_spec in sorted( + self._field_specs.items(), key=lambda t: t[0]): + num_tensors_for_field = len(field_spec._flat_tensor_specs) + field_tensors = tensor_list[pos:pos + num_tensors_for_field] + fields[field_name] = field_spec._from_compatible_tensor_list( + field_tensors) + pos += num_tensors_for_field + return StructuredTensor.from_fields(fields, self._shape) + + def _to_tensor_list_internal(self, value, batched): + """Returns a dict whose entries are each field's (batched) tensor_list. + + If a field is a StructuredTensor, then its entry will be a dict, + recursively. + + Args: + value: A StructuredTensor (conforming to `self`). + batched: A boolean. if True, produce `batched_tensor_list` for each field + otherwise produce `tensor_list`. + Returns: + A dict. + """ + result = [] + for field_name, field_spec in sorted( + self._field_specs.items(), key=lambda t: t[0]): + # pylint: disable=protected-access + field_value = value._fields[field_name] + if batched: + result.extend(field_spec._to_batched_tensor_list(field_value)) + else: + result.extend(field_spec._to_tensor_list(field_value)) + + return result # Regular expression used to determine whether a string is a valid field name. # Note: we plan to relax (or possibly eliminate) this in the future; you diff --git a/tensorflow/python/ops/structured/structured_tensor_spec_test.py b/tensorflow/python/ops/structured/structured_tensor_spec_test.py index 4637a1a51e5..9cf9acf5ac3 100644 --- a/tensorflow/python/ops/structured/structured_tensor_spec_test.py +++ b/tensorflow/python/ops/structured/structured_tensor_spec_test.py @@ -213,30 +213,56 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase, 'b': StructuredTensor.from_fields(shape=[2], fields={ 'x': [[5], [6]]})}), }, + { + 'unbatched': lambda: [ + StructuredTensor.from_fields(shape=[], fields={ + 'Ragged3d': ragged_factory_ops.constant_value([[1, 2], [3]]), + 'Ragged2d': ragged_factory_ops.constant_value([1]), + }), + StructuredTensor.from_fields(shape=[], fields={ + 'Ragged3d': ragged_factory_ops.constant_value([[1]]), + 'Ragged2d': ragged_factory_ops.constant_value([2, 3]), + })], + 'batch_size': 2, + 'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={ + 'Ragged3d': ragged_factory_ops.constant_value( + [[[1, 2], [3]], [[1]]]), + 'Ragged2d': ragged_factory_ops.constant_value([[1], [2, 3]]), + }), + 'use_only_batched_spec': True, + }, ]) # pyformat: disable - def testBatchUnbatchValues(self, unbatched, batch_size, batched): + def testBatchUnbatchValues(self, unbatched, batch_size, batched, + use_only_batched_spec=False): batched = batched() # Deferred init because it creates tensors. unbatched = unbatched() # Deferred init because it creates tensors. # Test batching. - unbatched_spec = type_spec.type_spec_from_value(unbatched[0]) + if use_only_batched_spec: + unbatched_spec = type_spec.type_spec_from_value(batched)._unbatch() + else: + unbatched_spec = type_spec.type_spec_from_value(unbatched[0]) unbatched_tensor_lists = [unbatched_spec._to_tensor_list(st) for st in unbatched] batched_tensor_list = [array_ops.stack(tensors) for tensors in zip(*unbatched_tensor_lists)] actual_batched = unbatched_spec._batch(batch_size)._from_tensor_list( batched_tensor_list) + self.assertTrue( + unbatched_spec._batch(batch_size).is_compatible_with(actual_batched)) self.assertAllEqual(actual_batched, batched) # Test unbatching batched_spec = type_spec.type_spec_from_value(batched) - batched_tensor_list = batched_spec._to_tensor_list(batched) + batched_tensor_list = batched_spec._to_batched_tensor_list(batched) unbatched_tensor_lists = zip( *[array_ops.unstack(tensor) for tensor in batched_tensor_list]) actual_unbatched = [ batched_spec._unbatch()._from_tensor_list(tensor_list) for tensor_list in unbatched_tensor_lists] self.assertLen(actual_unbatched, len(unbatched)) + for st in actual_unbatched: + self.assertTrue(batched_spec._unbatch().is_compatible_with(st)) for (actual, expected) in zip(actual_unbatched, unbatched): self.assertAllEqual(actual, expected) diff --git a/tensorflow/python/ops/tensor_array_ops_test.py b/tensorflow/python/ops/tensor_array_ops_test.py index 4f09ff5c22d..ec18fcd8271 100644 --- a/tensorflow/python/ops/tensor_array_ops_test.py +++ b/tensorflow/python/ops/tensor_array_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -72,6 +74,18 @@ class TensorArrayOpsTest(test.TestCase): self.assertAllEqual(fn(['a', 'b', 'c'], ['c', 'd', 'e']), [b'a', b'b', b'c', b'c', b'd', b'e']) + def test_init_numpy_shape(self): + @def_function.function + def fn(): + values = tensor_array_ops.TensorArray( + np.float32, + size=1, + dynamic_size=False, + element_shape=np.array((2, 3))) + values = values.write(0, np.ones((2, 3))) + return values.concat() + self.assertAllEqual(fn(), [[1., 1., 1.], [1., 1., 1.]]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/ops/v1_compat_tests/BUILD b/tensorflow/python/ops/v1_compat_tests/BUILD index 3f44e5208a2..37bff01d429 100644 --- a/tensorflow/python/ops/v1_compat_tests/BUILD +++ b/tensorflow/python/ops/v1_compat_tests/BUILD @@ -10,7 +10,6 @@ cuda_py_test( size = "medium", srcs = ["gradient_checker_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD index 622e14616ab..9996f5c9894 100644 --- a/tensorflow/python/profiler/BUILD +++ b/tensorflow/python/profiler/BUILD @@ -36,7 +36,6 @@ cuda_py_test( srcs = ["profiler_client_test.py"], python_version = "PY3", tags = ["no_pip"], - tfrt_enabled = True, deps = [ ":profiler_client", "//tensorflow/python/eager:test", @@ -64,7 +63,6 @@ cuda_py_test( "no_pip", "no_rocm", ], - tfrt_enabled = True, deps = [ ":profiler_v2", "//tensorflow/python:constant_op", @@ -124,7 +122,6 @@ cuda_py_test( srcs = ["profiler_test.py"], python_version = "PY3", tags = ["no_pip"], - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Node names are different with autojit deps = [ ":model_analyzer", @@ -185,7 +182,6 @@ cuda_py_test( "no_gpu", # b/136036359 "no_pip", ], - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Node names are different with autojit deps = [ ":profile_context", diff --git a/tensorflow/python/profiler/integration_test/BUILD b/tensorflow/python/profiler/integration_test/BUILD index b7034ad1ddf..a7b92cd714a 100644 --- a/tensorflow/python/profiler/integration_test/BUILD +++ b/tensorflow/python/profiler/integration_test/BUILD @@ -19,7 +19,6 @@ cuda_py_test( srcs = ["profiler_api_test.py"], python_version = "PY3", tags = [ - "external", # So that test suite reruns unconditionally. "no_pip", "no_rocm", ], diff --git a/tensorflow/python/profiler/integration_test/profiler_api_test.py b/tensorflow/python/profiler/integration_test/profiler_api_test.py index a6134000a87..4e2a9dfd4e3 100644 --- a/tensorflow/python/profiler/integration_test/profiler_api_test.py +++ b/tensorflow/python/profiler/integration_test/profiler_api_test.py @@ -83,7 +83,7 @@ class ProfilerApiTest(test_util.TensorFlowTestCase): model.fit(x=train_ds, epochs=2, steps_per_epoch=steps) - def test_single_worker_sampling_mode(self): + def test_single_worker_sampling_mode(self, delay_ms=None): """Test single worker sampling mode.""" def on_worker(port): @@ -100,6 +100,7 @@ class ProfilerApiTest(test_util.TensorFlowTestCase): host_tracer_level=2, python_tracer_level=0, device_tracer_level=1, + delay_ms=delay_ms, ) profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms, @@ -115,6 +116,11 @@ class ProfilerApiTest(test_util.TensorFlowTestCase): thread_worker.join(120) self._check_xspace_pb_exist(logdir) + def test_single_worker_sampling_mode_delayed(self): + """Test single worker sampling mode with delay.""" + + self.test_single_worker_sampling_mode(delay_ms=1000) + def test_single_worker_programmatic_mode(self): """Test single worker programmatic mode.""" logdir = self.get_temp_dir() @@ -128,6 +134,7 @@ class ProfilerApiTest(test_util.TensorFlowTestCase): _, steps, train_ds, model = _model_setup() model.fit(x=train_ds, epochs=2, steps_per_epoch=steps) profiler.stop() + self._check_xspace_pb_exist(logdir) self._check_tools_pb_exist(logdir) diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index 5adb6d0a4b1..cfc718b9d35 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -65,7 +65,6 @@ cuda_py_test( "no_gpu", # b/138442728 "no_pip", ], - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Node names are different with autojit deps = [ ":model_analyzer_testlib", @@ -120,6 +119,7 @@ tf_python_pybind_extension( ], deps = [ "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/convert:xplane_to_tools_data", "//tensorflow/core/profiler/convert:xplane_to_trace_events", "//tensorflow/core/profiler/lib:profiler_session_for_pybind", diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc index f0c289afe01..2b513547612 100644 --- a/tensorflow/python/profiler/internal/profiler_wrapper.cc +++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc @@ -29,6 +29,7 @@ limitations under the License. #include "pybind11/pytypes.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/xplane_to_tools_data.h" @@ -47,7 +48,7 @@ namespace { using ::tensorflow::RemoteProfilerSessionManagerOptions; // Profiler gives grace after profiling duration to terminate. -constexpr absl::Duration kSessionGraceTime = absl::Seconds(5); +constexpr absl::Duration kMinSessionGraceTime = absl::Seconds(60); tensorflow::Status ValidateHostPortPair(absl::string_view host_port) { tensorflow::uint32 port; @@ -105,9 +106,9 @@ void UpdateMaxSessionDuration(RemoteProfilerSessionManagerOptions& options) { // is bounded. DCHECK_GT(local_profiler_duration, 0); VLOG(3) << "duration_ms was given as " << local_profiler_duration; - // Max session duration includes the profiling session and grace time. - auto profile_duration = - absl::Milliseconds(local_profiler_duration) + kSessionGraceTime; + // Max session duration is the profiling session with grace time. + auto profile_duration = std::max( + kMinSessionGraceTime, absl::Milliseconds(local_profiler_duration) * 2); absl::Duration delay_duration; // When requested start timestamp is 0, profiling starts immediately. if (requested_start_ts > 0) { @@ -145,6 +146,7 @@ RemoteProfilerSessionManagerOptions GetOptionsLocked(absl::string_view logdir, VLOG(2) << "repository_path set to " << options.profiler_options().repository_path(); + int delay_ms = 0; for (const auto& kw : opts) { std::string key = py::cast(kw.first); if (key == "host_tracer_level") { @@ -159,11 +161,28 @@ RemoteProfilerSessionManagerOptions GetOptionsLocked(absl::string_view logdir, auto value = py::cast(kw.second); options.mutable_profiler_options()->set_python_tracer_level(value); VLOG(1) << "python_tracer_level set to " << value; + } else if (key == "delay_ms") { + if (!kw.second.is_none()) { + delay_ms = py::cast(kw.second); + } } else { LOG(WARNING) << "Unrecognised key: " << key; } } + if (delay_ms) { + absl::Time start_timestamp = now + absl::Milliseconds(delay_ms); + tensorflow::int64 start_timestamp_ns = absl::ToUnixNanos(start_timestamp); + options.mutable_profiler_options()->set_start_timestamp_ns( + start_timestamp_ns); + LOG(INFO) << "delay_ms was " << delay_ms << ", start_timestamp_ns set to " + << start_timestamp_ns << " [" << start_timestamp << "]"; + } else { + DCHECK_EQ(options.mutable_profiler_options()->start_timestamp_ns(), 0); + LOG(INFO) << "Profiling will start immediately because delay_ms was unset " + "or zero."; + } + return options; } @@ -234,6 +253,7 @@ class ProfilerSessionWrapper { tensorflow::profiler::XSpace xspace; tensorflow::Status status; status = session_->CollectData(&xspace); + xspace.add_hostnames(tensorflow::port::Hostname()); session_.reset(); status = tensorflow::profiler::ExportToTensorBoard(xspace, logdir_); tensorflow::MaybeRaiseRegisteredFromStatus(status); diff --git a/tensorflow/python/profiler/profiler_v2.py b/tensorflow/python/profiler/profiler_v2.py index bcba9d52d23..102a510906b 100644 --- a/tensorflow/python/profiler/profiler_v2.py +++ b/tensorflow/python/profiler/profiler_v2.py @@ -48,9 +48,10 @@ _profiler_lock = threading.Lock() @tf_export('profiler.experimental.ProfilerOptions', v1=[]) class ProfilerOptions( - collections.namedtuple( - 'ProfilerOptions', - ['host_tracer_level', 'python_tracer_level', 'device_tracer_level'])): + collections.namedtuple('ProfilerOptions', [ + 'host_tracer_level', 'python_tracer_level', 'device_tracer_level', + 'delay_ms' + ])): """Options for finer control over the profiler. Use `tf.profiler.ProfilerOptions` to control `tf.profiler` @@ -63,15 +64,22 @@ class ProfilerOptions( - enabled, 0 - disabled [default value is 0] device_tracer_level: Adjust device (TPU/GPU) tracing level. Values are: 1 - enabled, 0 - disabled [default value is 1] + delay_ms: Requests for all hosts to start profiling at a timestamp that is + `delay_ms` away from the current time. `delay_ms` is in milliseconds. If + zero, each host will start profiling immediately upon receiving the + request. Default value is None, allowing the profiler guess the best + value. + """ def __new__(cls, host_tracer_level=2, python_tracer_level=0, - device_tracer_level=1): + device_tracer_level=1, + delay_ms=None): return super(ProfilerOptions, cls).__new__(cls, host_tracer_level, python_tracer_level, - device_tracer_level) + device_tracer_level, delay_ms) @tf_export('profiler.experimental.start', v1=[]) diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index b18c7c1e738..5768cbdc15d 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -107,7 +107,6 @@ tf_py_test( name = "loader_test", size = "small", srcs = ["loader_test.py"], - tfrt_enabled = True, deps = [ ":builder", ":loader", @@ -166,7 +165,6 @@ tf_py_test( srcs = ["saved_model_test.py"], data = ["//tensorflow/cc/saved_model:saved_model_half_plus_two"], tags = ["no_windows"], - tfrt_enabled = True, deps = [ ":builder", ":constants", @@ -217,7 +215,6 @@ tf_py_test( name = "utils_test", size = "small", srcs = ["utils_test.py"], - tfrt_enabled = True, deps = [ ":utils", "//tensorflow/core:protos_all_py", @@ -250,7 +247,6 @@ tf_py_test( name = "signature_def_utils_test", size = "small", srcs = ["signature_def_utils_test.py"], - tfrt_enabled = True, deps = [ ":signature_constants", ":signature_def_utils", @@ -266,7 +262,6 @@ tf_py_test( name = "simple_save_test", size = "small", srcs = ["simple_save_test.py"], - tfrt_enabled = True, deps = [ ":loader", ":signature_constants", @@ -311,7 +306,6 @@ tf_py_test( name = "save_context_test", srcs = ["save_context_test.py"], srcs_version = "PY2AND3", - tfrt_enabled = True, deps = [ ":save_context", ":save_options", @@ -349,6 +343,7 @@ py_strict_library( "//tensorflow/python:framework", "//tensorflow/python:framework_ops", "//tensorflow/python:lib", + "//tensorflow/python:platform", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:tf_export", @@ -364,6 +359,7 @@ py_strict_library( "//tensorflow/python/training/tracking:base", "//tensorflow/python/training/tracking:graph_view", "//tensorflow/python/training/tracking:util", + "@absl_py//absl/logging", ], ) @@ -531,7 +527,6 @@ py_strict_library( tf_py_test( name = "revived_types_test", srcs = ["revived_types_test.py"], - tfrt_enabled = True, deps = [ ":revived_types", "//tensorflow/core:protos_all_py", @@ -567,7 +562,6 @@ py_strict_library( "//tensorflow/python:func_graph", "//tensorflow/python:function_def_to_graph", "//tensorflow/python:op_def_registry", - "//tensorflow/python:platform", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_spec", "//tensorflow/python:tf_decorator", @@ -575,6 +569,7 @@ py_strict_library( "//tensorflow/python:util", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:function", + "@absl_py//absl/logging", ], ) @@ -606,7 +601,6 @@ py_strict_library( tf_py_test( name = "nested_structure_coder_test", srcs = ["nested_structure_coder_test.py"], - tfrt_enabled = True, deps = [ ":nested_structure_coder", "//tensorflow/core:protos_all_py", @@ -651,7 +645,6 @@ py_strict_library( tf_py_test( name = "method_name_updater_test", srcs = ["method_name_updater_test.py"], - tfrt_enabled = True, deps = [ ":method_name_updater", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index 092e4177f44..4d24f7dd009 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections import re +from absl import logging from tensorflow.core.framework import function_pb2 from tensorflow.core.protobuf import saved_object_graph_pb2 @@ -32,7 +33,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.util import compat from tensorflow.python.util import nest @@ -220,7 +220,6 @@ def recreate_function(saved_function, concrete_functions): # instead of creating a new `Function` backed by a Python layer to # glue things together. Current approach is nesting functions deeper for each # serialization cycle. - coder = nested_structure_coder.StructureCoder() # Note: handling method functions is tricky since make_decorator does not @@ -237,7 +236,9 @@ def recreate_function(saved_function, concrete_functions): coder) def restored_function_body(*args, **kwargs): - """Calls a restored function.""" + """Calls a restored function or raises an error if no matching function.""" + if not saved_function.concrete_functions: + raise ValueError("Found zero restored functions for caller function.") # This is the format of function.graph.structured_input_signature. At this # point, the args and kwargs have already been canonicalized. inputs = (args, kwargs) diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index 0c64275ce01..381fe95bff0 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -48,6 +48,7 @@ from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training.saving import checkpoint_options from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.training.tracking import base +from tensorflow.python.training.tracking import data_structures from tensorflow.python.training.tracking import graph_view from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import util @@ -119,7 +120,7 @@ class Loader(object): """Helper class to load an object-based SavedModel.""" def __init__(self, object_graph_proto, saved_model_proto, export_dir, - ckpt_options): + ckpt_options, filters): meta_graph = saved_model_proto.meta_graphs[0] self._asset_file_def = meta_graph.asset_file_def self._operation_attributes = { @@ -131,6 +132,26 @@ class Loader(object): meta_graph.graph_def.library)) self._checkpoint_options = ckpt_options + # Stores user-defined node_filters argument. + self._node_filters = filters + # Stores map of string paths to integers. + self._node_path_to_id = self._convert_node_paths_to_ints() + self._loaded_nodes = {} + if isinstance(filters, dict): + # If node_filters is a dict, then the values may contain already created + # trackable objects. In this case, create a dictionary mapping node IDs to + # the already created nodes. This dict will be updated in + # `_retrieve_all_filtered_nodes` with tracked dependencies. + for node_path, node in filters.items(): + if isinstance(node, tuple): + self._loaded_nodes[self._node_path_to_id[node_path]] = node + else: + self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr) + + # Get a list of all integer node ids to load, or None if all nodes should be + # loaded. This list includes ids of child nodes. + self._filtered_nodes = self._retrieve_all_filtered_nodes() + for name, concrete_function in self._concrete_functions.items(): # Wrap all the concrete function so that they are capable of dealing with # both in replica and cross replica cases. @@ -145,6 +166,91 @@ class Loader(object): if not context.executing_eagerly(): ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) + def _convert_node_paths_to_ints(self): + """Maps all string node paths in node_filters to the int node ids.""" + if self._node_filters is None: + return None + path_to_int = {} + for node_id in self._node_filters: + int_node_id = None + if isinstance(node_id, str): + node_path = node_id.split(".") + if node_path[0] != "root": + raise ValueError( + "When passing string identifiers to node_filters, the first name" + " must be root.") + int_node_id = 0 + for n, name in enumerate(node_path[1:]): + int_node_id = self._find_node_child( + int_node_id, name, ".".join(node_path[:n+2])) + path_to_int[node_id] = int_node_id + else: + raise TypeError("Elements in node_filters must be strings.") + return path_to_int + + def _retrieve_all_filtered_nodes(self): + """Traverses through the object graph to get the IDs of all nodes to load. + + As a side-effect, if node_filters is a dictionary that contains already- + created objects, then the dependencies tracked by those objects will be + added to node_filters. + + Returns: + List of all nodes to load, or None if all nodes should be loaded. + + """ + if self._node_filters is None: + return None # All nodes should be loaded. + + all_filtered_nodes = set() + nodes_to_visit = list(self._node_filters) + + while nodes_to_visit: + node_path = nodes_to_visit.pop(0) + node_id = self._node_path_to_id[node_path] + if node_id in all_filtered_nodes: + continue + all_filtered_nodes.add(node_id) + + node, setter = self._loaded_nodes.get(node_id, (None, None)) + if node is not None: + if not isinstance(node, base.Trackable): + raise TypeError( + "Error when processing dictionary values passed to nodes_to_load." + "Object at {} is expected to be a checkpointable TensorFlow " + "object (e.g. tf.Variable, tf.Module or Keras layer)." + .format(node_path)) + node._maybe_initialize_trackable() # pylint: disable=protected-access + + for reference in self._proto.nodes[node_id].children: + child_object, _ = self._loaded_nodes.get( + reference.node_id, (None, None)) + + # See if node already tracks the child reference, in which case add the + # child to the loaded_nodes dict. + if child_object is None and node is not None: + child_object = node._lookup_dependency(reference.local_name) # pylint: disable=protected-access + if isinstance(child_object, data_structures.TrackableDataStructure): + # Make setattr a noop to avoid overwriting already existing data + # structures. + setter = lambda *args: None + + self._loaded_nodes[reference.node_id] = (child_object, setter) + + child_path = "{}.{}".format(node_path, reference.local_name) + self._node_path_to_id[child_path] = reference.node_id + nodes_to_visit.append(child_path) + + if 0 in all_filtered_nodes: + return None + return all_filtered_nodes + + def _find_node_child(self, node_id, child_name, path): + for reference in self._proto.nodes[node_id].children: + if reference.local_name == child_name: + return reference.node_id + raise ValueError("unable to find node {}".format(path)) + def _load_all(self): """Loads all nodes and functions from the SavedModel and their edges.""" self._load_nodes() @@ -159,7 +265,7 @@ class Loader(object): self._create_saveable_object_factories() def _create_saveable_object_factories(self): - for node_id, proto in enumerate(self._proto.nodes): + for node_id, proto in self._iter_all_nodes(): node = self.get(node_id) node._self_saveable_object_factories = {} # pylint: disable=protected-access for name, saveable_object_proto in proto.saveable_objects.items(): @@ -170,9 +276,24 @@ class Loader(object): def _load_edges(self): """Adds edges from objects to other objects and functions.""" - for node_id, object_proto in enumerate(self._proto.nodes): + for node_id, object_proto in self._iter_all_nodes(): self._add_object_graph_edges(object_proto, node_id) + # If root object isn't loaded, then create edges from the root for + # checkpoint compatibility. + if self._filtered_nodes is not None and 0 not in self._filtered_nodes: + root = self.get(0) + for node_path in self._node_filters: + loaded_node = self._nodes[self._node_path_to_id[node_path]] + path = node_path.split(".") + current_node = root + for name in path[1:-1]: + if not hasattr(current_node, name): + setattr(current_node, name, self._recreate_base_user_object()[0]) + current_node = getattr(current_node, name) + if not hasattr(current_node, path[-1]): + setattr(current_node, path[-1], loaded_node) + def _add_object_graph_edges(self, proto, node_id): """Adds edges from an object to its children.""" obj = self._nodes[node_id] @@ -214,7 +335,7 @@ class Loader(object): for name, proto in concrete_functions: concrete_function = self._concrete_functions[name] bound_inputs = [ - self._get_tensor_from_node(node_id) + self._get_tensor_from_node(node_id, name) for node_id in proto.bound_inputs] bound_variables = [ self._nodes[node_id] @@ -251,8 +372,14 @@ class Loader(object): # placeholder for this input. concrete_function.graph.capture(bound_input) - def _get_tensor_from_node(self, node_id): + def _get_tensor_from_node(self, node_id, fn_name): """Resolves a node id into a tensor to be captured for a function.""" + if self._node_filters is not None and self._nodes[node_id] is None: + raise ValueError( + "Error when processing nodes_to_load. Function \"{}\" requires " + "inputs/variables that are not loaded when nodes_to_load={}" + .format(fn_name, self._node_filters)) + with ops.init_scope(): obj = self._nodes[node_id] if distribute_utils.is_distributed_variable(obj): @@ -268,24 +395,39 @@ class Loader(object): return obj.resource_handle raise ValueError("Can't convert node %s to tensor" % (type(obj))) + def _initialize_loaded_nodes(self): + nodes = {} + node_setters = {} + for node_id, (node, setter) in self._loaded_nodes.items(): + nodes[node_id] = node + node_setters[node_id] = setter + return nodes, node_setters + + def _iter_all_nodes(self): + if self._filtered_nodes is None: + return enumerate(self._proto.nodes) + else: + return [(node_id, self._proto.nodes[node_id]) + for node_id in self._filtered_nodes] + def _load_nodes(self): """Load all saved objects.""" - # Maps from node ids to recreated objects - nodes = {} - # Maps from node ids to setter functions (same signature as setattr) for - # setting dependencies. - node_setters = {} + # `nodes` maps from node ids to recreated objects + # `node_setters` maps from node ids to setter functions + # (same signature as setattr) for setting dependencies. + nodes, node_setters = self._initialize_loaded_nodes() # Figure out which objects are slot variables. These objects are created # with Optimizer.add_slot rather than _recreate_variable. slot_variable_node_ids = set() - for proto in self._proto.nodes: + + for _, proto in self._iter_all_nodes(): for slot_variable_proto in proto.slot_variables: slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id) # Re-create everything except slot variables. - for node_id, proto in enumerate(self._proto.nodes): - if node_id in slot_variable_node_ids: + for node_id, proto in self._iter_all_nodes(): + if node_id in slot_variable_node_ids or nodes.get(node_id) is not None: # Defer recreating slot variables so we can use the public Optimizer # interface. continue @@ -295,7 +437,7 @@ class Loader(object): # Now that we have created the variables being optimized, we have enough # information to re-create slot variables for them. - for node_id, proto in enumerate(self._proto.nodes): + for node_id, proto in self._iter_all_nodes(): optimizer_object = nodes[node_id] for slot_variable_proto in proto.slot_variables: optimized_variable = nodes[ @@ -306,7 +448,13 @@ class Loader(object): nodes[slot_variable_proto.slot_variable_node_id] = slot_variable node_setters[slot_variable_proto.slot_variable_node_id] = setattr - self._nodes = [nodes[node_id] for node_id in range(len(self._proto.nodes))] + # If root object is not loaded, add a dummy root object for checkpoint + # compatibility. + if 0 not in nodes: + nodes[0] = self._recreate_base_user_object()[0] + + self._nodes = [nodes.get(node_id) + for node_id in range(len(self._proto.nodes))] self._node_setters = node_setters @property @@ -380,6 +528,8 @@ class Loader(object): return output_debug_info def get(self, node_id): + if isinstance(node_id, str): + node_id = self._node_path_to_id[node_id] return self._nodes[node_id] def _recreate(self, proto, node_id): @@ -408,7 +558,7 @@ class Loader(object): return self._recreate_base_user_object(proto, node_id) return looked_up - def _recreate_base_user_object(self, proto, node_id): + def _recreate_base_user_object(self, proto=None, node_id=None): del proto, node_id # Note: each user object has its own class. This allows making each one # individually callable by adding a `__call__` method to the classes of @@ -518,6 +668,103 @@ def _call_attribute(instance, *args, **kwargs): return instance.__call__(*args, **kwargs) +def load_partial(export_dir, filters, tags=None, options=None): + """Partially load a SavedModel (saved from V2). + + Similar to `tf.saved_model.load`, but with an additional argument that + lets you specify which nodes to load. + `tf.saved_model.load_partial(export_dir, ["root"])` and + `tf.saved_model.load(export_dir)` are equivalent. + + Note: This only works for SavedModels saved with TensorFlow V2 from + `tf.saved_model.save` or Keras. This will not load SavedModels save from + the Estimator API. + + In Tensorflow V2, SavedModel stores the **object graph** of the saved object. + The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras + layers, etc.) and edges that are the name of the attributes connecting the + objects. + + *Example 1* + + ``` + model = tf.Module() + model.child_layer = tf.Module() + model.child_layer.v = tf.Variable(5.) + tf.saved_model.save(model, '/tmp/model') + loaded = tf.__internal__.saved_model.load_partial( + ... '/tmp/model', + ... ['root.child_layer', 'root.child_layer.v']) + loaded['root.child_layer'].v.numpy() + 5. + loaded['root.child_layer'].v is loaded['root.child_layer.v'] + True + + *Example 2* + model = tf.Module() + model.child_layer = tf.Module() + model.child_layer.v = tf.Variable(5.) + >>> + tf.saved_model.save(model, '/tmp/model') + # Create a variable + new_variable = tf.Variable(0.) + loaded = tf.__internal__.saved_model.load_partial( + ... '/tmp/model', + ... {'root.child_layer': None, 'root.child_layer.v': new_variable}) + loaded['root.child_layer'].v.numpy() + 5. + new_variable.numpy() + 5. + ``` + + **Loading under different distribution strategies** + You can load different parts of the model under different distribution + strategies. Note that this is very experimental so use with care. + + ``` + model = tf.Module() + model.layer_1 = tf.Module() + model.layer_1.v = tf.Variable(5.) + model.layer_2 = tf.Module() + model.layer_2.v = tf.Variable(7.) + tf.saved_model.save(model, '/tmp/model') + # Load with no strategy + loaded = tf.__internal__.saved_model.load_partial( + ... '/tmp/model', + ... ['root.layer_1']) + loaded['root.layer_1'].v + + strategy = tf.distribute.MirroredStrategy() + with strategy.scope(): + ... loaded2 = tf.__internal__.saved_model.load_partial( + ... '/tmp/model', + ... ['root.layer_2']) + loaded2['root.layer_2'].v + MirroredVariable:{ + 0: + } + ``` + + Args: + export_dir: The SavedModel directory to load from. + filters: A list or dictionary where each element or key is a string + path to nodes that should be loaded. Node paths consist of all the child + attribute names to reach that node in the form: `root.{attribute_name}`. + The loader will load all of the specified nodes and their recursive + descendants. When this option is defined, the loader will return a + dictionary mapping the node paths to the loaded objects. + tags: A tag or sequence of tags identifying the MetaGraph to load. Optional + if the SavedModel contains a single MetaGraph, as for those exported from + `tf.saved_model.save`. + options: `tf.saved_model.LoadOptions` object that specifies options for + loading. + + Returns: + A dictionary mapping node paths from the filter to loaded objects. + """ + return load_internal(export_dir, tags, options, filters=filters) + + @tf_export("saved_model.load", v1=["saved_model.load_v2"]) def load(export_dir, tags=None, options=None): """Load a SavedModel from `export_dir`. @@ -597,8 +844,8 @@ def load(export_dir, tags=None, options=None): tags: A tag or sequence of tags identifying the MetaGraph to load. Optional if the SavedModel contains a single MetaGraph, as for those exported from `tf.saved_model.save`. - options: Optional, `tf.saved_model.LoadOptions` object that specifies - options for loading. + options: `tf.saved_model.LoadOptions` object that specifies options for + loading. Returns: A trackable object with a `signatures` attribute mapping from signature @@ -609,10 +856,11 @@ def load(export_dir, tags=None, options=None): Raises: ValueError: If `tags` don't match a MetaGraph in the SavedModel. """ - return load_internal(export_dir, tags, options) + return load_internal(export_dir, tags, options)["root"] -def load_internal(export_dir, tags=None, options=None, loader_cls=Loader): +def load_internal(export_dir, tags=None, options=None, loader_cls=Loader, + filters=None): """Loader implementation.""" options = options or load_options.LoadOptions() if tags is not None and not isinstance(tags, set): @@ -639,7 +887,7 @@ def load_internal(export_dir, tags=None, options=None, loader_cls=Loader): with ops.init_scope(): try: loader = loader_cls(object_graph_proto, saved_model_proto, export_dir, - ckpt_options) + ckpt_options, filters) except errors.NotFoundError as err: raise FileNotFoundError( str(err) + "\n If trying to load on a different device from the " @@ -654,7 +902,14 @@ def load_internal(export_dir, tags=None, options=None, loader_cls=Loader): root.tensorflow_git_version = ( meta_graph_def.meta_info_def.tensorflow_git_version) else: + if filters: + raise ValueError("SavedModels saved from Tensorflow V1 or Estimator (any " + "version) cannot be loaded with node filters.") with ops.init_scope(): root = load_v1_in_v2.load(export_dir, tags) root.graph_debug_info = debug_info - return root + + if filters: + return {node_id: loader.get(node_id) for node_id in filters} + else: + return {"root": root} diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index a8244850308..5d3315b0d68 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -86,7 +86,8 @@ def cycle(obj, cycles, signatures=None): @parameterized.named_parameters( dict(testcase_name="ReloadOnce", cycles=1), dict(testcase_name="ReloadTwice", cycles=2), - dict(testcase_name="ReloadThrice", cycles=3)) + dict(testcase_name="ReloadThrice", cycles=3) +) class LoadTest(test.TestCase, parameterized.TestCase): def test_structure_import(self, cycles): @@ -2028,6 +2029,55 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase): tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32)]) cycle(root, 1) + def test_load_partial_object(self): + root = module.Module() + root.variables_holder = module.Module() + root.variables_holder.v = variables.Variable(1.) + + class Adder(module.Module): + + @def_function.function(input_signature=[tensor_spec.TensorSpec(shape=[])]) + def __call__(self, y): + root.variables_holder.v.assign_add(y) + return 1 + + root.adder = Adder() + + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(root, save_dir) + + imported = load.load_partial(save_dir, + ["root.variables_holder.v", "root.adder"]) + v = imported["root.variables_holder.v"] + adder = imported["root.adder"] + self.assertEqual(self.evaluate(v), 1) + adder(5) + self.assertEqual(self.evaluate(v), 6) + + with self.assertRaisesRegex(ValueError, "requires inputs/variables"): + imported = load.load_partial(save_dir, ["root.adder"]) + + def test_call_untraced_function_raises_error(self): + + class ObjWithFunction(module.Module): + + @def_function.function + def foo(self, a): + return a + + root = ObjWithFunction() + with self.assertLogs(level="WARNING") as logs: + loaded = cycle(root, 1) + + expected_save_message = ( + "WARNING:absl:Found untraced functions such as foo while saving " + "(showing 1 of 1). These functions will not be directly callable after " + "loading.") + self.assertIn(expected_save_message, logs.output) + + with self.assertRaisesRegex( + ValueError, "Found zero restored functions for caller function."): + loaded.foo(1) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index d03ae64a480..87a65724ab9 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -23,6 +23,7 @@ import functools import gc import os +from absl import logging from tensorflow.core.framework import versions_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 @@ -42,6 +43,7 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import builder_impl from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import function_serialization @@ -72,6 +74,9 @@ _UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant)) _CapturedConstant = collections.namedtuple("_CapturedConstant", ["eager_tensor", "graph_tensor"]) +# Number of untraced functions to display to user in warning message. +_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5 + class _AugmentedGraphView(graph_view.ObjectGraphView): """An extendable graph which also tracks functions attached to objects. @@ -185,6 +190,7 @@ class _SaveableView(object): self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary() self.slot_variables = slot_variables self.concrete_functions = [] + self.untraced_functions = [] self.saveable_objects_for_node, all_saveable_functions = ( self._add_saveable_objects()) @@ -220,10 +226,20 @@ class _SaveableView(object): function._list_all_concrete_functions_for_serialization()) # pylint: disable=protected-access else: concrete_functions = [function] + if not concrete_functions: + self.untraced_functions.append(function._name) + for concrete_function in concrete_functions: if concrete_function.name not in seen_function_names: seen_function_names.add(concrete_function.name) self.concrete_functions.append(concrete_function) + if self.untraced_functions: + logging.warning( + "Found untraced functions such as %s while saving (showing %d of %d)." + " These functions will not be directly callable after loading.", + ", ".join(self.untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]), + min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self.untraced_functions)), + len(self.untraced_functions)) def _add_saveable_objects(self): """Retrieves SaveablesObjects and traces their save/restore functions.""" @@ -735,14 +751,17 @@ def _serialize_object_graph(saveable_view, asset_file_def_index): if serialized is not None: proto.concrete_functions[name].CopyFrom(serialized) + saved_object_metadata = False for obj, obj_proto in zip(saveable_view.nodes, proto.nodes): - _write_object_proto(obj, obj_proto, asset_file_def_index, - saveable_view.function_name_map) - return proto + has_saved_object_metadata = _write_object_proto( + obj, obj_proto, asset_file_def_index, saveable_view.function_name_map) + saved_object_metadata = saved_object_metadata or has_saved_object_metadata + return proto, saved_object_metadata def _write_object_proto(obj, proto, asset_file_def_index, function_name_map): """Saves an object into SavedObject proto.""" + has_saved_object_metadata = False # The metadata field will be deprecated. if isinstance(obj, tracking.Asset): proto.asset.SetInParent() proto.asset.asset_file_def_index = asset_file_def_index[obj] @@ -778,11 +797,14 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map): if registered_type_proto is None: # Fallback for types with no matching registration # pylint:disable=protected-access + metadata = obj._tracking_metadata + if metadata: + has_saved_object_metadata = True registered_type_proto = saved_object_graph_pb2.SavedUserObject( identifier=obj._object_identifier, version=versions_pb2.VersionDef( producer=1, min_consumer=1, bad_consumers=[]), - metadata=obj._tracking_metadata) + metadata=metadata) # pylint:enable=protected-access proto.user_object.CopyFrom(registered_type_proto) @@ -795,6 +817,7 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map): # documentation. if hasattr(obj, "_write_object_proto"): obj._write_object_proto(proto, options) # pylint: disable=protected-access + return has_saved_object_metadata def _export_debug_info(exported_graph, export_dir): @@ -992,8 +1015,7 @@ def save(obj, export_dir, signatures=None, options=None): instances with input signatures or concrete functions. Keys of such a dictionary may be arbitrary strings, but will typically be from the `tf.saved_model.signature_constants` module. - options: Optional, `tf.saved_model.SaveOptions` object that specifies - options for saving. + options: `tf.saved_model.SaveOptions` object for configuring save options. Raises: ValueError: If `obj` is not trackable. @@ -1144,11 +1166,27 @@ def _build_meta_graph_impl(obj, for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access function_aliases[fdef.name] = alias - object_graph_proto = _serialize_object_graph(saveable_view, - asset_info.asset_index) + object_graph_proto, saved_object_metadata = _serialize_object_graph( + saveable_view, asset_info.asset_index) meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) - return meta_graph_def, exported_graph, object_saver, asset_info + if saved_object_metadata: + tf_logging.warn( + 'FOR KERAS USERS: The object that you are saving contains one or more ' + 'Keras models or layers. If you are loading the SavedModel with ' + '`tf.keras.models.load_model`, continue reading (otherwise, you may ' + 'ignore the following instructions). Please change your code to save ' + 'with `tf.keras.models.save_model` or `model.save`, and confirm that ' + 'the file "keras.metadata" exists in the export directory. In the ' + 'future, Keras will only load the SavedModels that have this file. In ' + 'other words, `tf.saved_model.save` will no longer write SavedModels ' + 'that can be recovered as Keras models (this will apply in TF 2.5).' + '\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,' + ' this property has been used to save metadata in the SavedModel. The ' + 'metadta field will be deprecated soon, so please move the metadata to ' + 'a different file.') + + return (meta_graph_def, exported_graph, object_saver, asset_info) def _build_meta_graph(obj, diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index de828f7ea1b..77e19b8a8ec 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -289,6 +289,30 @@ class SaveTest(test.TestCase, parameterized.TestCase): save_dir = os.path.join(self.get_temp_dir(), "saved_model") save.save(model, save_dir) + def test_save_function_no_trace(self): + + class ObjWithFunction(module.Module): + + @def_function.function + def foo(self, a): + return a + + @def_function.function + def bar(self, a): + return a + 1 + + root = ObjWithFunction() + root.bar(1) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + with self.assertLogs(level="WARNING") as logs: + save.save(root, save_dir) + + expected_message = ( + "WARNING:absl:Found untraced functions such as foo while saving " + "(showing 1 of 1). These functions will not be directly callable after " + "loading.") + self.assertIn(expected_message, logs.output) + def test_find_default_save_function(self): class ObjWithDefaultSignature(util.Checkpoint): @@ -645,6 +669,23 @@ class SaveTest(test.TestCase, parameterized.TestCase): with self.assertRaises(ValueError): loader.load(session, [tag_constants.SERVING], export_dir) + def test_concrete_function_with_set_shape(self,): + # Serialized concrete function should retain the shape from the TensorSpec, + # instead of using the shape of the inputs (which are changed by set_shape). + @def_function.function + def f(x): + x.set_shape((5, 1)) + return x + + root = tracking.AutoTrackable() + path = os.path.join(self.get_temp_dir(), "saved_model") + concrete = f.get_concrete_function( + tensor_spec.TensorSpec((None, 1), name="name")) + save.save(root, path, signatures={"key": concrete}) + imported = load.load(path) + self.assertEqual(imported.signatures["key"].structured_input_signature[1], + {"name": tensor_spec.TensorSpec((None, 1), name="name")}) + class VariablePolicyEnumTest(test.TestCase): diff --git a/tensorflow/python/saved_model/signature_serialization.py b/tensorflow/python/saved_model/signature_serialization.py index 74f76c690f2..14f4df81380 100644 --- a/tensorflow/python/saved_model/signature_serialization.py +++ b/tensorflow/python/saved_model/signature_serialization.py @@ -124,15 +124,26 @@ def canonicalize_signatures(signatures): structured_outputs = signature_function(**kwargs) return _normalize_outputs( structured_outputs, signature_function.name, signature_key) - # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names - # always match keyword arguments. tensor_spec_signature = {} - for keyword, tensor in zip( + if signature_function.structured_input_signature is not None: + # The structured input signature may contain other non-tensor arguments. + inputs = filter( + lambda x: isinstance(x, tensor_spec.TensorSpec), + nest.flatten(signature_function.structured_input_signature, + expand_composites=True)) + else: + # Structured input signature isn't always defined for some functions. + inputs = signature_function.inputs + + for keyword, inp in zip( signature_function._arg_keywords, # pylint: disable=protected-access - signature_function.inputs): + inputs): keyword = compat.as_str(keyword) - tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor( - tensor, name=keyword) + if isinstance(inp, tensor_spec.TensorSpec): + spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword) + else: + spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword) + tensor_spec_signature[keyword] = spec final_concrete = signature_wrapper._get_concrete_function_garbage_collected( # pylint: disable=protected-access **tensor_spec_signature) # pylint: disable=protected-access diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 36165deeaad..1c46c228cf0 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -580,10 +580,17 @@ PYBIND11_MODULE(_pywrap_tfe, m) { // MLIR Logic m.def("TF_IsMlirBridgeEnabled", [] { - return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; + // Since python protobuf enums are integers, cast to an integer before + // returning the enum to python. + return static_cast( + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge); }); m.def("TF_EnableMlirBridge", [](bool enabled) { - tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled; + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = + enabled + ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED + : tensorflow::ConfigProto::Experimental:: + MLIR_BRIDGE_ROLLOUT_DISABLED; }); m.def("TF_EnableXlaDevices", [] { tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; @@ -1047,11 +1054,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) { TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get()); }); m.def("TFE_CollectiveOpsCheckPeerHealth", - [](const py::handle& ctx, const char* task) { + [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) { tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx), - task, status.get()); + task, timeout_in_ms, status.get()); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); }); m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices); diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 049594ead90..7b1f85dc0e9 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -368,6 +368,74 @@ py_test( ], ) +py_binary( + name = "make_aot_compile_models", + srcs = ["make_aot_compile_models.py"], + python_version = "PY3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/saved_model", + "@absl_py//absl:app", + "@absl_py//absl/flags", + "@six_archive//:six", + ], +) + +EMITTED_AOT_SAVE_MODEL_OBJECTS = [ + "x_matmul_y_large/saved_model.pb", + "x_matmul_y_large/variables/variables.index", + "x_matmul_y_small/saved_model.pb", + "x_matmul_y_small/variables/variables.index", +] + +genrule( + name = "create_models_for_aot_compile", + outs = EMITTED_AOT_SAVE_MODEL_OBJECTS, + cmd = ( + "$(location :make_aot_compile_models) --out_dir $(@D)" + ), + exec_tools = [":make_aot_compile_models"], + tags = ["no_rocm"], +) + +filegroup( + name = "aot_saved_models", + srcs = EMITTED_AOT_SAVE_MODEL_OBJECTS, +) + +saved_model_compile_aot( + name = "aot_compiled_x_matmul_y_large", + cpp_class = "XMatmulYLarge", + directory = "//tensorflow/python/tools:x_matmul_y_large", + filegroups = [":aot_saved_models"], + force_without_xla_support_flag = False, + tags = ["no_rocm"], +) + +saved_model_compile_aot( + name = "aot_compiled_x_matmul_y_large_multithreaded", + cpp_class = "XMatmulYLargeMultithreaded", + directory = "//tensorflow/python/tools:x_matmul_y_large", + filegroups = [":aot_saved_models"], + force_without_xla_support_flag = False, + multithreading = True, + tags = ["no_rocm"], +) + +saved_model_compile_aot( + name = "aot_compiled_x_matmul_y_small", + cpp_class = "XMatmulYSmall", + directory = "//tensorflow/python/tools:x_matmul_y_small", + filegroups = [":aot_saved_models"], + force_without_xla_support_flag = False, + tags = ["no_rocm"], +) + saved_model_compile_aot( name = "aot_compiled_x_plus_y", cpp_class = "XPlusY", @@ -402,6 +470,32 @@ saved_model_compile_aot( variables_to_feed = "variable_x", ) +sh_test( + name = "large_matmul_no_multithread_test", + srcs = if_xla_available( + ["no_xla_multithread_symbols_test.sh"], + if_false = ["skip_test.sh"], + ), + args = if_xla_available(["$(location :aot_compiled_x_matmul_y_large.o)"]), + data = if_xla_available([":aot_compiled_x_matmul_y_large.o"]), +) + +sh_test( + name = "large_matmul_yes_multithread_test", + srcs = if_xla_available( + [ + "xla_multithread_symbols_test.sh", + ], + if_false = ["skip_test.sh"], + ), + args = if_xla_available( + ["$(location :aot_compiled_x_matmul_y_large_multithreaded.o)"], + ), + data = if_xla_available( + [":aot_compiled_x_matmul_y_large_multithreaded.o"], + ), +) + tf_cc_test( name = "aot_compiled_test", srcs = if_xla_available([ @@ -413,8 +507,12 @@ tf_cc_test( ] + if_xla_available([ ":aot_compiled_vars_and_arithmetic", ":aot_compiled_vars_and_arithmetic_frozen", + ":aot_compiled_x_matmul_y_large", + ":aot_compiled_x_matmul_y_large_multithreaded", + ":aot_compiled_x_matmul_y_small", ":aot_compiled_x_plus_y", "//tensorflow/core:test", + "//third_party/eigen3", "//tensorflow/core/platform:logging", ]), ) diff --git a/tensorflow/python/tools/aot_compiled_test.cc b/tensorflow/python/tools/aot_compiled_test.cc index 3e8084590db..0c15e638841 100644 --- a/tensorflow/python/tools/aot_compiled_test.cc +++ b/tensorflow/python/tools/aot_compiled_test.cc @@ -13,10 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic.h" #include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic_frozen.h" +#include "tensorflow/python/tools/aot_compiled_x_matmul_y_large.h" +#include "tensorflow/python/tools/aot_compiled_x_matmul_y_large_multithreaded.h" +#include "tensorflow/python/tools/aot_compiled_x_matmul_y_small.h" #include "tensorflow/python/tools/aot_compiled_x_plus_y.h" namespace tensorflow { @@ -30,6 +36,97 @@ TEST(AOTCompiledSavedModelTest, XPlusY) { ASSERT_NEAR(model.result_fetch_output_0(), 7.0f, /*abs_error=*/1e-6f); } +TEST(AOTCompiledSavedModelTest, XMatmulYLarge) { + XMatmulYLarge model; + // Calculation is: output_0 = x @ y. + EXPECT_EQ(model.arg_feed_x_count(), 3000 * 5000); + EXPECT_EQ(model.arg_feed_y_count(), 5000 * 4000); + EXPECT_EQ(model.result0_count(), 3000 * 4000); + + Eigen::Tensor arg_feed_x(3000, 5000); + Eigen::Tensor arg_feed_y(5000, 4000); + arg_feed_x.setRandom(); + arg_feed_y.setRandom(); + + // Set up dimensions for standard matmul. + const Eigen::array, 1> product_dims = { + Eigen::IndexPair(1, 0)}; + // Ground truth matmul. + const Eigen::Tensor expected_output0 = + arg_feed_x.contract(arg_feed_y, product_dims); + + model.set_arg_feed_x_data(arg_feed_x.data()); + model.set_arg_feed_y_data(arg_feed_y.data()); + CHECK(model.Run()); + EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0), + /*abs_error=*/1e-6f); + EXPECT_NEAR(model.result_fetch_output_0(2999, 3999), + expected_output0(2999, 3999), + /*abs_error=*/1e-6f); +} + +TEST(AOTCompiledSavedModelTest, XMatmulYLargeMultithreaded) { + XMatmulYLargeMultithreaded model; + + Eigen::ThreadPool pool(2); + Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); + model.set_thread_pool(&device); + + // Calculation is: output_0 = x @ y. + EXPECT_EQ(model.arg_feed_x_count(), 3000 * 5000); + EXPECT_EQ(model.arg_feed_y_count(), 5000 * 4000); + EXPECT_EQ(model.result0_count(), 3000 * 4000); + + Eigen::Tensor arg_feed_x(3000, 5000); + Eigen::Tensor arg_feed_y(5000, 4000); + arg_feed_x.setRandom(); + arg_feed_y.setRandom(); + + // Set up dimensions for standard matmul. + const Eigen::array, 1> product_dims = { + Eigen::IndexPair(1, 0)}; + // Ground truth matmul. + const Eigen::Tensor expected_output0 = + arg_feed_x.contract(arg_feed_y, product_dims); + + model.set_arg_feed_x_data(arg_feed_x.data()); + model.set_arg_feed_y_data(arg_feed_y.data()); + CHECK(model.Run()); + EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0), + /*abs_error=*/1e-3f); + EXPECT_NEAR(model.result_fetch_output_0(2999, 3999), + expected_output0(2999, 3999), + /*abs_error=*/1e-3f); +} + +TEST(AOTCompiledSavedModelTest, XMatmulYSmall) { + XMatmulYSmall model; + // Calculation is: output_0 = x @ y. + EXPECT_EQ(model.arg_feed_x_count(), 3 * 5); + EXPECT_EQ(model.arg_feed_y_count(), 5 * 4); + EXPECT_EQ(model.result0_count(), 3 * 4); + + Eigen::Tensor arg_feed_x(3, 5); + Eigen::Tensor arg_feed_y(5, 4); + arg_feed_x.setRandom(); + arg_feed_y.setRandom(); + + // Set up dimensions for standard matmul. + const Eigen::array, 1> product_dims = { + Eigen::IndexPair(1, 0)}; + // Ground truth matmul. + const Eigen::Tensor expected_output0 = + arg_feed_x.contract(arg_feed_y, product_dims); + + model.set_arg_feed_x_data(arg_feed_x.data()); + model.set_arg_feed_y_data(arg_feed_y.data()); + CHECK(model.Run()); + EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0), + /*abs_error=*/1e-6f); + EXPECT_NEAR(model.result_fetch_output_0(2, 3), expected_output0(2, 3), + /*abs_error=*/1e-6f); +} + TEST(AOTCompiledSavedModelTest, VarsAndArithmetic) { VarsAndArithmeticFrozen frozen_model; // Calculation is: diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index c04ab6900e9..eccd39d3f4e 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -31,6 +31,8 @@ TENSORFLOW_API_INIT_FILES = [ "distribute/__init__.py", "distribute/cluster_resolver/__init__.py", "distribute/experimental/__init__.py", + "distribute/experimental/coordinator/__init__.py", + "distribute/experimental/partitioners/__init__.py", "dtypes/__init__.py", "errors/__init__.py", "experimental/__init__.py", diff --git a/tensorflow/python/tools/api/generator/output_init_files_test.py b/tensorflow/python/tools/api/generator/output_init_files_test.py index 25035e0567a..37589b0ced4 100644 --- a/tensorflow/python/tools/api/generator/output_init_files_test.py +++ b/tensorflow/python/tools/api/generator/output_init_files_test.py @@ -27,6 +27,8 @@ from tensorflow.lite.python import lite as _tflite_for_api_traversal from tensorflow.python import modules_with_exports from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base +from tensorflow.python.distribute import parameter_server_strategy_v2 +from tensorflow.python.distribute.coordinator import cluster_coordinator from tensorflow.python.framework import combinations from tensorflow.python.framework import test_combinations # pylint: enable=unused-import diff --git a/tensorflow/python/tools/make_aot_compile_models.py b/tensorflow/python/tools/make_aot_compile_models.py new file mode 100644 index 00000000000..2a8f3550472 --- /dev/null +++ b/tensorflow/python/tools/make_aot_compile_models.py @@ -0,0 +1,82 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generate some SavedModels for use by AOT compilation tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import flags + +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import app +from tensorflow.python.saved_model import save +from tensorflow.python.training.tracking import tracking + + +flags.DEFINE_string('out_dir', None, + 'Directory to output saved models to.') + +FLAGS = flags.FLAGS + + +def create_large_matmul_savedmodel(out_dir): + """Create a SavedModel that performs a large matmul.""" + root = tracking.AutoTrackable() + root.f = def_function.function( + lambda x, y: math_ops.matmul(x, y), # pylint: disable=unnecessary-lambda + input_signature=[tensor_spec.TensorSpec([3000, 5000], dtypes.float32), + tensor_spec.TensorSpec([5000, 4000], dtypes.float32),]) + root.f(x=array_ops.zeros((3000, 5000)), + y=array_ops.zeros((5000, 4000))) + save_dir = os.path.join(out_dir, 'x_matmul_y_large') + save.save(root, save_dir, root.f) + # This simple SavedModel lacks any variables, but we need to create a + # variables.index file to make bazel genrule happy. + with open(os.path.join(save_dir, 'variables', 'variables.index'), 'w'): + pass + + +def create_small_matmul_savedmodel(out_dir): + """Create a SavedModel that performs a small matmul.""" + root = tracking.AutoTrackable() + root.f = def_function.function( + lambda x, y: math_ops.matmul(x, y), # pylint: disable=unnecessary-lambda + input_signature=[tensor_spec.TensorSpec([3, 5], dtypes.float32), + tensor_spec.TensorSpec([5, 4], dtypes.float32),]) + root.f(x=array_ops.zeros((3, 5)), + y=array_ops.zeros((5, 4))) + save_dir = os.path.join(out_dir, 'x_matmul_y_small') + save.save(root, save_dir, root.f) + # This simple SavedModel lacks any variables, but we need to create a + # variables.index file to make bazel genrule happy. + with open(os.path.join(save_dir, 'variables', 'variables.index'), 'w'): + pass + + +def main(unused_args): + create_small_matmul_savedmodel(FLAGS.out_dir) + create_large_matmul_savedmodel(FLAGS.out_dir) + + +if __name__ == '__main__': + flags.mark_flag_as_required('out_dir') + app.run(main) diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py35_full/release_pip_rename.sh b/tensorflow/python/tools/no_xla_multithread_symbols_test.sh old mode 100644 new mode 100755 similarity index 66% rename from tensorflow/tools/ci_build/release/windows/gpu_py35_full/release_pip_rename.sh rename to tensorflow/python/tools/no_xla_multithread_symbols_test.sh index 039f9516d86..468c283ad98 --- a/tensorflow/tools/ci_build/release/windows/gpu_py35_full/release_pip_rename.sh +++ b/tensorflow/python/tools/no_xla_multithread_symbols_test.sh @@ -1,5 +1,4 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +13,15 @@ # limitations under the License. # ============================================================================== set -e -set -x -source tensorflow/tools/ci_build/release/common.sh - -# Copy and rename to tensorflow -for f in $(ls py_test_dir/tensorflow-*cp3*-cp3*m-win_amd64.whl); do - copy_to_new_project_name "${f}" tensorflow_gpu -done +SYMBOLS=$(nm "$@" | grep __xla_cpu_runtime) +if echo "${SYMBOLS}" | grep -q SingleThread; then + exit 0 +else + echo "" 1>&2 + echo "Did not see SingleThread runtime symbol in $@:" 1>&2 + echo "" 1>&2 + echo "${SYMBOLS}" 1>&2 + echo "" 1>&2 + exit 1 +fi diff --git a/tensorflow/python/tools/saved_model_aot_compile.py b/tensorflow/python/tools/saved_model_aot_compile.py index bf955ad825c..d1478e205d3 100644 --- a/tensorflow/python/tools/saved_model_aot_compile.py +++ b/tensorflow/python/tools/saved_model_aot_compile.py @@ -19,11 +19,10 @@ from __future__ import division from __future__ import print_function import collections - import copy -import hashlib import os import pipes +import re import shlex import six @@ -217,7 +216,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, target_triple, target_cpu, variables_to_feed=(), - enable_multithreading=False): + multithreading=False): """Compile a `MetaGraphDef` to header+object files in `output_prefix`. Use XLA AOT (`tfcompile`) to convert the given meta graph and @@ -245,8 +244,9 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, user; these won't be frozen. If `None`, then we will extract all the variables in the graph and mark them as to-feed. The default behavior is an empty tuple: all variables must be frozen. - enable_multithreading: Not implemented. Enable multithreading in the - compiled computation. + multithreading: Whether to enable multithreading in the compiled + computation. Note that if using this option, the resulting object files + may have external dependencies on multithreading libraries like nsync. Raises: RuntimeError: If tensorflow was not built with XLA. @@ -254,23 +254,20 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, issue importing the tfcompile python wrapper. ValueError: If `meta_graph_def.signature_def[signature_def_key]` is missing or has empty outputs. - NotImplementedError: If `enable_multithreading is True`. """ if _pywrap_tfcompile_import_error: - raise _pywrap_tfcompile_import_error + raise _pywrap_tfcompile_import_error # pylint: disable=raising-bad-type - if enable_multithreading: - raise NotImplementedError( - 'Multithreading is not currently supported because it requires ' - 'additional dependencies in the AOT runtime.') else: # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap # so that we can set these directly instead of relying on env vars. xla_flags = os.environ.get('XLA_FLAGS') if not xla_flags: - xla_flags = '--xla_cpu_multi_thread_eigen=false' + xla_flags = '--xla_cpu_multi_thread_eigen={}'.format( + 'true' if multithreading else 'false') else: - xla_flags += ',--xla_cpu_multi_thread_eigen=false' + xla_flags += ',--xla_cpu_multi_thread_eigen={}'.format( + 'true' if multithreading else 'false') os.environ['XLA_FLAGS'] = xla_flags signature_def_map = meta_graph_def.signature_def @@ -352,10 +349,9 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, output_dir = os.path.dirname(output_prefix) file_io.recursive_create_dir(output_dir) - entry_digest = hashlib.md5() - entry_digest.update(str(config).encode()) - entry_digest.update(str(graph_def).encode()) - entry_digest = entry_digest.hexdigest() + entry_point = re.sub( + '[^0-9a-zA-Z]+', '_', + '__xla_' + output_prefix + '__' + cpp_class) logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir)) @@ -371,7 +367,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, cpp_class=cpp_class, target_triple=target_triple, target_cpu=target_cpu, - entry_point='entry_{}'.format(entry_digest), + entry_point=entry_point, out_function_object='{}.o'.format(output_prefix), out_header='{}.h'.format(output_prefix), out_metadata_object='{}_metadata.o'.format(output_prefix), diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 0c8b8f5576b..124686dff13 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -821,6 +821,7 @@ def aot_compile_cpu(args): variables_to_feed = None # We will identify them after. else: variables_to_feed = args.variables_to_feed.split(',') + saved_model_aot_compile.aot_compile_cpu_meta_graph_def( checkpoint_path=checkpoint_path, meta_graph_def=saved_model_utils.get_meta_graph_def( @@ -831,7 +832,7 @@ def aot_compile_cpu(args): target_triple=args.target_triple, target_cpu=args.target_cpu, cpp_class=args.cpp_class, - enable_multithreading=args.enable_multithreading) + multithreading=args.multithreading.lower() not in ('f', 'false', '0')) def add_show_subparser(subparsers): @@ -1140,11 +1141,13 @@ def add_aot_compile_cpu_subparser(subparsers): '(this applies to all input arguments from the signature as ' 'well).')) parser_compile.add_argument( - '--enable_multithreading', - type=bool, - default='', - help=('*NOT CURRENTLY SUPPORTED* ' - 'Enable multithreading in the compiled computation.')) + '--multithreading', + type=str, + default='False', + help=('Enable multithreading in the compiled computation. ' + 'Note that if using this option, the resulting object files ' + 'may have external dependencies on multithreading libraries ' + 'like nsync.')) parser_compile.set_defaults(func=aot_compile_cpu) diff --git a/tensorflow/tools/ci_build/release/macos/cpu_libtensorflow/release.sh b/tensorflow/python/tools/skip_test.sh old mode 100644 new mode 100755 similarity index 72% rename from tensorflow/tools/ci_build/release/macos/cpu_libtensorflow/release.sh rename to tensorflow/python/tools/skip_test.sh index ccc80e1bafd..5c9407175fe --- a/tensorflow/tools/ci_build/release/macos/cpu_libtensorflow/release.sh +++ b/tensorflow/python/tools/skip_test.sh @@ -1,5 +1,4 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,11 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -set -e -set -x - -# Install latest bazel -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk -tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh +exit 0 diff --git a/tensorflow/python/tools/tools.bzl b/tensorflow/python/tools/tools.bzl index 79f771bbcad..db886746006 100644 --- a/tensorflow/python/tools/tools.bzl +++ b/tensorflow/python/tools/tools.bzl @@ -21,6 +21,7 @@ def saved_model_compile_aot( variables_to_feed = "", target_triple = None, target_cpu = None, + multithreading = False, force_without_xla_support_flag = True, tags = None): """Compile a SavedModel directory accessible from a filegroup. @@ -93,6 +94,11 @@ def saved_model_compile_aot( target architecture's triple). Similar to clang's -target flag. target_cpu: The LLVM cpu name used for compilation. Similar to clang's -mcpu flag. + multithreading: Whether to compile multithreaded AOT code. + Note, this increases the set of dependencies for binaries using + the AOT library at both build and runtime. For example, + the resulting object files may have external dependencies on + multithreading libraries like nsync. force_without_xla_support_flag: Whether to compile even when `--define=with_xla_support=true` is not set. If `False`, and the define is not passed when building, then the created `cc_library` @@ -135,6 +141,7 @@ def saved_model_compile_aot( "--cpp_class {} ".format(cpp_class) + "--variables_to_feed {} ".format(variables_to_feed) + "--signature_def_key {} ".format(signature_def) + + "--multithreading {} ".format(multithreading) + "--target_triple " + target_triple + " " + ("--target_cpu " + target_cpu + " " if target_cpu else "") + "--tag_set {} ".format(tag_set) diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_pip_rename.sh b/tensorflow/python/tools/xla_multithread_symbols_test.sh old mode 100644 new mode 100755 similarity index 67% rename from tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_pip_rename.sh rename to tensorflow/python/tools/xla_multithread_symbols_test.sh index 43982623109..9576c762112 --- a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_pip_rename.sh +++ b/tensorflow/python/tools/xla_multithread_symbols_test.sh @@ -1,5 +1,4 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,12 +13,15 @@ # limitations under the License. # ============================================================================== set -e -set -x -source tensorflow/tools/ci_build/release/common.sh - -# Rename to tensorflow_cpu -for f in $(ls py_test_dir/tensorflow-*cp3*-cp3*m-win_amd64.whl); do - copy_to_new_project_name "${f}" tensorflow_cpu - rm "${f}" -done +SYMBOLS=$(nm "$@" | grep __xla_cpu_runtime) +if echo "${SYMBOLS}" | grep -q SingleThread; then + echo "" 1>&2 + echo "Saw a SingleThread runtime symbol in $@:" 1>&2 + echo "" 1>&2 + echo "${SYMBOLS}" 1>&2 + echo "" 1>&2 + exit 1 +else + exit 0 +fi diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index 0afe2086d93..22a77fc89b0 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -337,7 +337,6 @@ tf_py_test( "no_oss", # TODO(b/131157871): Reenable in OSS when fixed "no_windows", # TODO: needs investigation on Windows ], - tfrt_enabled = True, deps = [ ":tpu", "//tensorflow/python:client_testlib", @@ -351,7 +350,6 @@ tf_py_test( name = "tpu_sharding_test", size = "small", srcs = ["tpu_sharding_test.py"], - tfrt_enabled = True, deps = [ ":tpu", "//tensorflow/python:client_testlib", @@ -363,7 +361,6 @@ tf_py_test( name = "bfloat16_test", size = "small", srcs = ["bfloat16_test.py"], - tfrt_enabled = True, deps = [ ":tpu", "//tensorflow/python:client_testlib", @@ -375,7 +372,6 @@ tf_py_test( name = "tpu_infeed_test", size = "small", srcs = ["tpu_infeed_test.py"], - tfrt_enabled = True, deps = [ ":tpu", "//tensorflow/python:framework", @@ -387,7 +383,6 @@ tf_py_test( name = "topology_test", size = "medium", srcs = ["topology_test.py"], - tfrt_enabled = True, deps = [ ":tpu", "//tensorflow/python:framework_test_lib", @@ -464,7 +459,6 @@ tf_py_test( "feature_column_test.py", ], main = "feature_column_test.py", - tfrt_enabled = True, deps = [ ":feature_column", "//tensorflow/python:client_testlib", @@ -487,7 +481,6 @@ tf_py_test( "feature_column_v2_test.py", ], main = "feature_column_v2_test.py", - tfrt_enabled = True, deps = [ ":feature_column_v2", "//tensorflow/python:client_testlib", @@ -622,7 +615,6 @@ tf_py_test( ], python_version = "PY3", srcs_version = "PY2AND3", - tfrt_enabled = True, deps = [ ":tpu_embedding_v2", "//tensorflow/python/compat:v2_compat", diff --git a/tensorflow/python/tpu/client/BUILD b/tensorflow/python/tpu/client/BUILD index a6973d4ec22..dc94bffb64e 100644 --- a/tensorflow/python/tpu/client/BUILD +++ b/tensorflow/python/tpu/client/BUILD @@ -43,7 +43,6 @@ tf_py_test( tags = [ "no_oss_py2", ], - tfrt_enabled = True, deps = [ ":client", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index c6b5a256b42..084ec1f3dba 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -288,6 +288,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._outer_device_function_stack = None self._oc_dev_fn_stack = None self._outside_compilation_cluster = None + self._outside_compilation_v2_context = None self._outside_compilation_counter = 0 self._in_gradient_colocation = None self._gradient_colocation_stack = [] @@ -379,6 +380,21 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): def EnterGradientColocation(self, op, gradient_uid): if op is not None: + if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access + # If we are in TF 2 functions (control flow V2 functions, or + # tf.function()), we need to attach _xla_outside_compilation attribute + # directly because we are not in TPUReplicateContext. + try: + outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii") + except ValueError: + # The attr was not present: do nothing. + return + parts = outside_attr.split(".") + cluster = parts[0] + "." + gradient_uid + self._outside_compilation_v2_context = OutsideCompilationV2Context( + cluster) + self._outside_compilation_v2_context.Enter() + return self._gradient_colocation_stack.append(op) if not self._outside_compilation_cluster: try: @@ -418,6 +434,17 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): def ExitGradientColocation(self, op, gradient_uid): if op is not None: + if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access + # Inside a TF2 tf.function or control flow graph and `op` was not + # marked to be outside compiled. + assert self._outside_compilation_v2_context is None + return + if self._outside_compilation_v2_context is not None: + # Inside a TF2 tf.function or control flow graph and `op` was + # marked to be outside compiled. + self._outside_compilation_v2_context.Exit() + self._outside_compilation_v2_context = None + return if not self._gradient_colocation_stack: raise errors.InternalError( op.node_def, op, diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py index 3a45a579631..7c42bb2c41f 100644 --- a/tensorflow/python/tpu/tpu_embedding.py +++ b/tensorflow/python/tpu/tpu_embedding.py @@ -370,6 +370,8 @@ class _OptimizationParameters(object): clip_weight_max: Optional[float], weight_decay_factor: Optional[float], multiply_weight_decay_factor_by_learning_rate: Optional[bool], + clip_gradient_min: Optional[float] = None, + clip_gradient_max: Optional[float] = None, ): self.learning_rate = learning_rate self.use_gradient_accumulation = use_gradient_accumulation @@ -378,6 +380,8 @@ class _OptimizationParameters(object): self.weight_decay_factor = weight_decay_factor self.multiply_weight_decay_factor_by_learning_rate = ( multiply_weight_decay_factor_by_learning_rate) + self.clip_gradient_min = clip_gradient_min + self.clip_gradient_max = clip_gradient_max @tf_export(v1=['tpu.experimental.AdagradParameters']) @@ -409,6 +413,8 @@ class AdagradParameters(_OptimizationParameters): clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + clip_gradient_min: Optional[float] = None, + clip_gradient_max: Optional[float] = None, ): """Optimization parameters for Adagrad. @@ -425,11 +431,20 @@ class AdagradParameters(_OptimizationParameters): weights are not decayed. multiply_weight_decay_factor_by_learning_rate: if true, `weight_decay_factor` is multiplied by the current learning rate. + clip_gradient_min: the minimum value to clip by; None means -infinity. + clip_gradient_max: the maximum value to clip by; None means +infinity. """ - super(AdagradParameters, - self).__init__(learning_rate, use_gradient_accumulation, - clip_weight_min, clip_weight_max, weight_decay_factor, - multiply_weight_decay_factor_by_learning_rate) + super(AdagradParameters, self).__init__( + learning_rate=learning_rate, + use_gradient_accumulation=use_gradient_accumulation, + clip_weight_min=clip_weight_min, + clip_weight_max=clip_weight_max, + weight_decay_factor=weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate=( + multiply_weight_decay_factor_by_learning_rate), + clip_gradient_min=clip_gradient_min, + clip_gradient_max=clip_gradient_max, + ) if initial_accumulator <= 0: raise ValueError('Adagrad initial_accumulator must be positive') self.initial_accumulator = initial_accumulator @@ -455,6 +470,8 @@ class ProximalAdagradParameters(_OptimizationParameters): clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + clip_gradient_min: Optional[float] = None, + clip_gradient_max: Optional[float] = None, ): """Optimization parameters for Adagrad. @@ -474,11 +491,20 @@ class ProximalAdagradParameters(_OptimizationParameters): weights are not decayed. multiply_weight_decay_factor_by_learning_rate: if true, `weight_decay_factor` is multiplied by the current learning rate. + clip_gradient_min: the minimum value to clip by; None means -infinity. + clip_gradient_max: the maximum value to clip by; None means +infinity. """ - super(ProximalAdagradParameters, - self).__init__(learning_rate, use_gradient_accumulation, - clip_weight_min, clip_weight_max, weight_decay_factor, - multiply_weight_decay_factor_by_learning_rate) + super(ProximalAdagradParameters, self).__init__( + learning_rate=learning_rate, + use_gradient_accumulation=use_gradient_accumulation, + clip_weight_min=clip_weight_min, + clip_weight_max=clip_weight_max, + weight_decay_factor=weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate=( + multiply_weight_decay_factor_by_learning_rate), + clip_gradient_min=clip_gradient_min, + clip_gradient_max=clip_gradient_max, + ) if initial_accumulator <= 0: raise ValueError('Adagrad initial_accumulator must be positive') if l1_regularization_strength < 0.: @@ -527,6 +553,8 @@ class AdamParameters(_OptimizationParameters): clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + clip_gradient_min: Optional[float] = None, + clip_gradient_max: Optional[float] = None, ): """Optimization parameters for Adam. @@ -551,11 +579,20 @@ class AdamParameters(_OptimizationParameters): weights are not decayed. multiply_weight_decay_factor_by_learning_rate: if true, `weight_decay_factor` is multiplied by the current learning rate. + clip_gradient_min: the minimum value to clip by; None means -infinity. + clip_gradient_max: the maximum value to clip by; None means +infinity. """ - super(AdamParameters, - self).__init__(learning_rate, use_gradient_accumulation, - clip_weight_min, clip_weight_max, weight_decay_factor, - multiply_weight_decay_factor_by_learning_rate) + super(AdamParameters, self).__init__( + learning_rate=learning_rate, + use_gradient_accumulation=use_gradient_accumulation, + clip_weight_min=clip_weight_min, + clip_weight_max=clip_weight_max, + weight_decay_factor=weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate=( + multiply_weight_decay_factor_by_learning_rate), + clip_gradient_min=clip_gradient_min, + clip_gradient_max=clip_gradient_max, + ) if beta1 < 0. or beta1 >= 1.: raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1)) if beta2 < 0. or beta2 >= 1.: @@ -608,6 +645,8 @@ class FtrlParameters(_OptimizationParameters): multiply_linear_by_learning_rate: bool = False, beta: float = 0, allow_zero_accumulator: bool = False, + clip_gradient_min: Optional[float] = None, + clip_gradient_max: Optional[float] = None, ): """Optimization parameters for Ftrl. @@ -644,11 +683,20 @@ class FtrlParameters(_OptimizationParameters): allow_zero_accumulator: Changes the implementation of the square root to allow for the case of initial_accumulator_value being zero. This will cause a slight performance drop. + clip_gradient_min: the minimum value to clip by; None means -infinity. + clip_gradient_max: the maximum value to clip by; None means +infinity. """ - super(FtrlParameters, - self).__init__(learning_rate, use_gradient_accumulation, - clip_weight_min, clip_weight_max, weight_decay_factor, - multiply_weight_decay_factor_by_learning_rate) + super(FtrlParameters, self).__init__( + learning_rate=learning_rate, + use_gradient_accumulation=use_gradient_accumulation, + clip_weight_min=clip_weight_min, + clip_weight_max=clip_weight_max, + weight_decay_factor=weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate=( + multiply_weight_decay_factor_by_learning_rate), + clip_gradient_min=clip_gradient_min, + clip_gradient_max=clip_gradient_max, + ) if learning_rate_power > 0.: raise ValueError('learning_rate_power must be less than or equal to 0. ' 'got {}.'.format(learning_rate_power)) @@ -703,6 +751,8 @@ class ProximalYogiParameters(_OptimizationParameters): clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + clip_gradient_min: Optional[float] = None, + clip_gradient_max: Optional[float] = None, ): """Optimization parameters for Proximal Yogi. @@ -728,11 +778,20 @@ class ProximalYogiParameters(_OptimizationParameters): weights are not decayed. multiply_weight_decay_factor_by_learning_rate: if true, `weight_decay_factor` is multiplied by the current learning rate. + clip_gradient_min: the minimum value to clip by; None means -infinity. + clip_gradient_max: the maximum value to clip by; None means +infinity. """ - super(ProximalYogiParameters, - self).__init__(learning_rate, use_gradient_accumulation, - clip_weight_min, clip_weight_max, weight_decay_factor, - multiply_weight_decay_factor_by_learning_rate) + super(ProximalYogiParameters, self).__init__( + learning_rate=learning_rate, + use_gradient_accumulation=use_gradient_accumulation, + clip_weight_min=clip_weight_min, + clip_weight_max=clip_weight_max, + weight_decay_factor=weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate=( + multiply_weight_decay_factor_by_learning_rate), + clip_gradient_min=clip_gradient_min, + clip_gradient_max=clip_gradient_max, + ) if beta1 < 0. or beta1 >= 1.: raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1)) if beta2 < 0. or beta2 >= 1.: @@ -783,6 +842,8 @@ class MomentumParameters(_OptimizationParameters): clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + clip_gradient_min: Optional[float] = None, + clip_gradient_max: Optional[float] = None, ): """Optimization parameters for momentum. @@ -807,6 +868,8 @@ class MomentumParameters(_OptimizationParameters): weights are not decayed. multiply_weight_decay_factor_by_learning_rate: if true, `weight_decay_factor` is multiplied by the current learning rate. + clip_gradient_min: the minimum value to clip by; None means -infinity. + clip_gradient_max: the maximum value to clip by; None means +infinity. """ super(MomentumParameters, self).__init__( learning_rate=learning_rate, @@ -816,6 +879,8 @@ class MomentumParameters(_OptimizationParameters): weight_decay_factor=weight_decay_factor, multiply_weight_decay_factor_by_learning_rate=( multiply_weight_decay_factor_by_learning_rate), + clip_gradient_min=clip_gradient_min, + clip_gradient_max=clip_gradient_max, ) self.momentum = momentum self.use_nesterov = use_nesterov @@ -851,6 +916,8 @@ class RMSPropParameters(_OptimizationParameters): clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + clip_gradient_min: Optional[float] = None, + clip_gradient_max: Optional[float] = None, ): """Optimization parameters for RMS prop. @@ -868,6 +935,8 @@ class RMSPropParameters(_OptimizationParameters): weights are not decayed. multiply_weight_decay_factor_by_learning_rate: if true, `weight_decay_factor` is multiplied by the current learning rate. + clip_gradient_min: the minimum value to clip by; None means -infinity. + clip_gradient_max: the maximum value to clip by; None means +infinity. """ super(RMSPropParameters, self).__init__( learning_rate=learning_rate, @@ -877,6 +946,8 @@ class RMSPropParameters(_OptimizationParameters): weight_decay_factor=weight_decay_factor, multiply_weight_decay_factor_by_learning_rate=( multiply_weight_decay_factor_by_learning_rate), + clip_gradient_min=clip_gradient_min, + clip_gradient_max=clip_gradient_max, ) self.rho = rho self.momentum = momentum @@ -910,6 +981,8 @@ class StochasticGradientDescentParameters(_OptimizationParameters): clip_weight_max: Optional[float] = None, weight_decay_factor: Optional[float] = None, multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + clip_gradient_min: Optional[float] = None, + clip_gradient_max: Optional[float] = None, ): """Optimization parameters for stochastic gradient descent. @@ -921,11 +994,20 @@ class StochasticGradientDescentParameters(_OptimizationParameters): weights are not decayed. multiply_weight_decay_factor_by_learning_rate: if true, `weight_decay_factor` is multiplied by the current learning rate. + clip_gradient_min: the minimum value to clip by; None means -infinity. + clip_gradient_max: the maximum value to clip by; None means +infinity. """ - super(StochasticGradientDescentParameters, - self).__init__(learning_rate, False, clip_weight_min, clip_weight_max, - weight_decay_factor, - multiply_weight_decay_factor_by_learning_rate) + super(StochasticGradientDescentParameters, self).__init__( + learning_rate=learning_rate, + use_gradient_accumulation=False, + clip_weight_min=clip_weight_min, + clip_weight_max=clip_weight_max, + weight_decay_factor=weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate=( + multiply_weight_decay_factor_by_learning_rate), + clip_gradient_min=clip_gradient_min, + clip_gradient_max=clip_gradient_max, + ) DeviceConfig = collections.namedtuple('DeviceConfig', @@ -1285,6 +1367,14 @@ class TPUEmbedding(object): optimization_parameters_pb2.GradientAccumulationStatus.ENABLED if optimization_parameters.use_gradient_accumulation else optimization_parameters_pb2.GradientAccumulationStatus.DISABLED) + + if optimization_parameters.clip_gradient_min is not None: + parameters.gradient_clipping_limits.lower.value = ( + optimization_parameters.clip_gradient_min) + if optimization_parameters.clip_gradient_max is not None: + parameters.gradient_clipping_limits.upper.value = ( + optimization_parameters.clip_gradient_max) + if optimization_parameters.clip_weight_min is not None: parameters.clipping_limits.lower.value = ( optimization_parameters.clip_weight_min) diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py index 1417824e53c..413c6eb2264 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2.py +++ b/tensorflow/python/tpu/tpu_embedding_v2.py @@ -67,7 +67,7 @@ _NAME_KEY = "_tpu_embedding_layer" # sharded variables that can be used in the PSStrategy with optimizers. # We implement just enough of the of a tf.Variable so that this could be passed # to an optimizer. -class TPUShardedVariable(sharded_variable.ShardedVariable): +class TPUShardedVariable(sharded_variable.ShardedVariableMixin): """A ShardedVariable class for TPU.""" @property @@ -372,8 +372,9 @@ class TPUEmbedding(tracking.AutoTrackable): self._config_proto = self._create_config_proto() - logging.info("Initializing TPU Embedding engine with config: %s", - self._config_proto) + logging.info("Initializing TPU Embedding engine.") + tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto) + @def_function.function def load_config(): tpu.initialize_system_for_tpu_embedding(self._config_proto) diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils.py b/tensorflow/python/tpu/tpu_embedding_v2_utils.py index e04f1f0281a..33ff73ed706 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_utils.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_utils.py @@ -23,9 +23,12 @@ import abc import math import typing from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union + +from absl import logging import six from tensorflow.core.protobuf.tpu import optimization_parameters_pb2 +from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 from tensorflow.python.distribute import sharded_variable from tensorflow.python.framework import ops from tensorflow.python.ops import init_ops_v2 @@ -731,3 +734,18 @@ class FeatureConfig(object): max_sequence_length=self.max_sequence_length, name=self.name) ) + + +def log_tpu_embedding_configuration( + config: tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration) -> None: + """Logs a TPUEmbeddingConfiguration proto across multiple statements. + + Args: + config: TPUEmbeddingConfiguration proto to log. Necessary because + logging.info has a maximum length to each log statement, which + particularly large configs can exceed. + """ + logging.info("Beginning log of TPUEmbeddingConfiguration.") + for line in str(config).splitlines(): + logging.info(line) + logging.info("Done with log of TPUEmbeddingConfiguration.") diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py b/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py index 48797b00009..770ca1fc407 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized +from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 from tensorflow.python.compat import v2_compat from tensorflow.python.platform import test from tensorflow.python.tpu import tpu_embedding_v2_utils @@ -88,6 +89,34 @@ class ConfigTest(test.TestCase): ) +class TPUEmbeddingConfigurationTest(test.TestCase): + + def test_no_truncate(self): + truncate_length = 14937 # Experimentally maximum string length loggable. + + config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration() + for i in range(500): + td = config.table_descriptor.add() + td.name = 'table_{}'.format(i) + td.vocabulary_size = i + config.num_hosts = 2 + config.num_tensor_cores = 4 + config.batch_size_per_tensor_core = 128 + + self.assertGreater( + len(str(config)), truncate_length, + 'Test sanity check: generated config should be of truncating length.') + + with self.assertLogs() as logs: + tpu_embedding_v2_utils.log_tpu_embedding_configuration(config) + + self.assertIn('table_499', ''.join(logs.output)) + for line in logs.output: + self.assertLess( + len(line), truncate_length, + 'Logging function lines should not be of truncating length.') + + if __name__ == '__main__': v2_compat.enable_v2_behavior() test.main() diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py index 4eb6429f3c8..30bfabdff7c 100644 --- a/tensorflow/python/tpu/tpu_outside_compilation_test.py +++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py @@ -34,6 +34,7 @@ from tensorflow.python.eager import remote from tensorflow.python.eager import test from tensorflow.python.framework import config from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.lib.io import tf_record from tensorflow.python.ops import array_ops @@ -450,6 +451,36 @@ class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase): strategy.experimental_local_results(train_step()), constant_op.constant(2916., shape=(strategy.num_replicas_in_sync))) + def testColocateGradientWithOutsideCompiledOp(self): + strategy = get_tpu_strategy() + + @def_function.function + def train_step(): + + @def_function.function + def tpu_fn(x): + x1 = tpu.outside_compilation(math_ops.sqrt, x) + grad = gradients_impl.gradients([x1], [x], + colocate_gradients_with_ops=True)[0] + sqrt = [ + op for op in ops.get_default_graph().get_operations() + if op.type == "Sqrt" + ][0] + sqrt_grad = [ + op for op in ops.get_default_graph().get_operations() + if op.type == "SqrtGrad" + ][0] + assert sqrt.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) == b"0" + assert (sqrt_grad.get_attr( + tpu._OUTSIDE_COMPILATION_ATTR) == b"0.gradients/uid") + return grad + + return strategy.run(tpu_fn, args=(25.0,)) + + self.assertAllEqual( + strategy.experimental_local_results(train_step()), + constant_op.constant(.1, shape=(strategy.num_replicas_in_sync))) + class OutsideCompilationOnUnsupportedOpTest(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD index cf2d89b0d1f..ca961f67cab 100644 --- a/tensorflow/python/training/BUILD +++ b/tensorflow/python/training/BUILD @@ -602,7 +602,6 @@ tf_py_test( tags = [ "noasan", # TODO(b/161236904): flaky timeout in trying to start gRPC server ], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -624,7 +623,6 @@ tf_py_test( srcs = ["server_lib_multiple_containers_test.py"], grpc_enabled = True, python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -645,7 +643,6 @@ tf_py_test( srcs = ["server_lib_same_variables_clear_container_test.py"], grpc_enabled = True, python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -666,7 +663,6 @@ tf_py_test( srcs = ["server_lib_same_variables_clear_test.py"], grpc_enabled = True, python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -687,7 +683,6 @@ tf_py_test( srcs = ["server_lib_same_variables_no_clear_test.py"], grpc_enabled = True, python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -708,7 +703,6 @@ tf_py_test( srcs = ["server_lib_sparse_job_test.py"], grpc_enabled = True, python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -735,7 +729,6 @@ cuda_py_test( "no_oss", # Test flaky due to port collisions. "oss_serial", ], - tfrt_enabled = True, deps = [ ":device_setter", "//tensorflow/python:client_testlib", @@ -762,7 +755,6 @@ tf_py_test( "notsan", # data race due to b/62910646 "oss_serial", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -840,7 +832,6 @@ cuda_py_test( "checkpoint_management_test.py", ], python_version = "PY3", - tfrt_enabled = True, deps = [ ":checkpoint_management", ":saver", @@ -962,7 +953,6 @@ tf_py_test( "noasan", # http://b/30379628 "notsan", # http://b/30379628 ], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client", @@ -982,7 +972,6 @@ tf_py_test( "noasan", # http://b/30782289 "notsan", # http://b/30782289 ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -1050,7 +1039,6 @@ tf_py_test( grpc_enabled = True, python_version = "PY3", tags = ["no_windows"], - tfrt_enabled = True, deps = [ ":checkpoint_management", ":saver", @@ -1106,7 +1094,6 @@ tf_py_test( size = "small", srcs = ["training_util_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":training_util", "//tensorflow/python:client_testlib", @@ -1121,7 +1108,6 @@ cuda_py_test( size = "medium", srcs = ["adam_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":adam", "//tensorflow/python:array_ops", @@ -1150,7 +1136,6 @@ cuda_py_test( "no_windows", # b/139083295: bfloat16 tests fail on Windows "notsan", ], - tfrt_enabled = True, deps = [ ":moving_averages", ":saver", @@ -1264,7 +1249,6 @@ tf_py_test( "no_windows", "notsan", # intermittent races on a few percent of runs ], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client", @@ -1313,7 +1297,6 @@ tf_py_test( size = "small", srcs = ["checkpoint_ops_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:checkpoint_ops_gen", "//tensorflow/python:client", @@ -1334,7 +1317,6 @@ tf_py_test( size = "medium", srcs = ["warm_starting_util_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1382,7 +1364,6 @@ tf_py_test( "no_pip", "notsan", # b/67945581 ], - tfrt_enabled = True, deps = [ ":checkpoint_management", ":monitored_session", @@ -1408,7 +1389,6 @@ tf_py_test( size = "medium", srcs = ["input_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/training/experimental/BUILD b/tensorflow/python/training/experimental/BUILD index afc4cd673db..0e437666e19 100644 --- a/tensorflow/python/training/experimental/BUILD +++ b/tensorflow/python/training/experimental/BUILD @@ -107,7 +107,6 @@ cuda_py_test( size = "small", srcs = ["mixed_precision_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":mixed_precision", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py index 542311c75d8..1bf8d4f542d 100644 --- a/tensorflow/python/training/experimental/loss_scale.py +++ b/tensorflow/python/training/experimental/loss_scale.py @@ -37,20 +37,31 @@ from tensorflow.python.util.tf_export import tf_export @six.add_metaclass(abc.ABCMeta) -@deprecation.deprecated_endpoints('train.experimental.LossScale') -@tf_export('mixed_precision.experimental.LossScale', - 'train.experimental.LossScale') +@deprecation.deprecated_endpoints('mixed_precision.experimental.LossScale', + 'train.experimental.LossScale') +@tf_export( + 'mixed_precision.experimental.LossScale', + 'train.experimental.LossScale', + v1=[ + 'mixed_precision.LossScale', + 'mixed_precision.experimental.LossScale', + 'train.experimental.LossScale' + ]) class LossScale(trackable.Trackable): - """Base class for all loss scales. + """Base class for all TF1 loss scales. + + WARNING: This class is deprecated and will be unexposed from the TF 2 + namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will only + be accessible as `tf.compat.v1.mixed_precision.LossScale`. Additionally in + 2.5, you will no longer be able to pass a `LossScale` to a + `tf.keras.mixed_precision.Policy`. All the functionality in this class has + been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this class + is no longer needed. This is an abstract base class, so you cannot instantiate it directly. Instead, use one of its concrete subclasses: - * `tf.mixed_precision.experimental.DynamicLossScale` (recommended) - * `tf.mixed_precision.experimental.FixedLossScale` - - It's recommended to use a loss scale with a - `tf.keras.mixed_precision.experimental.LossScaleOptimizer`, as its easier than - using a loss scale directly. + * `tf.compat.v1.mixed_precision.DynamicLossScale` + * `tf.compat.v1.mixed_precision.FixedLossScale` Loss scaling is a process that multiplies the loss by a multiplier called the loss scale, and divides each gradient by the same multiplier. The pseudocode @@ -198,16 +209,35 @@ class LossScale(trackable.Trackable): return cls(**config) -@deprecation.deprecated_endpoints('train.experimental.FixedLossScale') -@tf_export('mixed_precision.experimental.FixedLossScale', - 'train.experimental.FixedLossScale') +@deprecation.deprecated_endpoints('mixed_precision.experimental.FixedLossScale', + 'train.experimental.FixedLossScale') +@tf_export( + 'mixed_precision.experimental.FixedLossScale', + 'train.experimental.FixedLossScale', + v1=[ + 'mixed_precision.FixedLossScale', + 'mixed_precision.experimental.FixedLossScale', + 'train.experimental.FixedLossScale' + ]) class FixedLossScale(LossScale): """Loss scale with a fixed value. + WARNING: This class is deprecated and will be unexposed from the TF 2 + namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will only + be accessible as `tf.compat.v1.mixed_precision.FixedLossScale`. Additionally + in 2.5, you will no longer be able to pass a `FixedLossScale` to a + `tf.keras.mixed_precision.Policy`. All the functionality in this class has + been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this class + is no longer needed. + The loss scale is not updated for the lifetime of instances of this class. A given instance of this class always returns the same number when called. """ + @deprecation.deprecated( + None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. ' + 'LossScaleOptimizer now has all the functionality of ' + 'FixedLossScale') def __init__(self, loss_scale_value): """Creates the fixed loss scale. @@ -280,12 +310,28 @@ def _assign_if_finite(var, value): control_flow_ops.no_op) -@deprecation.deprecated_endpoints('train.experimental.DynamicLossScale') -@tf_export('mixed_precision.experimental.DynamicLossScale', - 'train.experimental.DynamicLossScale') +@deprecation.deprecated_endpoints( + 'mixed_precision.experimental.DynamicLossScale', + 'train.experimental.DynamicLossScale') +@tf_export( + 'mixed_precision.experimental.DynamicLossScale', + 'train.experimental.DynamicLossScale', + v1=[ + 'mixed_precision.DynamicLossScale', + 'mixed_precision.experimental.DynamicLossScale', + 'train.experimental.DynamicLossScale' + ]) class DynamicLossScale(LossScale): """Loss scale that dynamically adjusts itself. + WARNING: This class is deprecated and will be unexposed from the TF 2 + namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will only + be accessible as `tf.compat.v1.mixed_precision.DynamicLossScale`. Additionally + in 2.5, you will no longer be able to pass a `DynamicLossScale` to a + `tf.keras.mixed_precision.Policy`. All the functionality in this class has + been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this class + is no longer needed. + Dynamic loss scaling works by adjusting the loss scale as training progresses. The goal is to keep the loss scale as high as possible without overflowing the gradients. As long as the gradients do not overflow, raising the loss scale @@ -299,6 +345,10 @@ class DynamicLossScale(LossScale): overflowing. """ + @deprecation.deprecated( + None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. ' + 'LossScaleOptimizer now has all the functionality of ' + 'DynamicLossScale') def __init__(self, initial_loss_scale=2 ** 15, # See docstring for why this is big. increment_period=2000, diff --git a/tensorflow/python/training/experimental/loss_scale_optimizer.py b/tensorflow/python/training/experimental/loss_scale_optimizer.py index c07c8cca60a..0c63177132d 100644 --- a/tensorflow/python/training/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/training/experimental/loss_scale_optimizer.py @@ -24,10 +24,14 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training.experimental import loss_scale as loss_scale_module +from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export -@tf_export(v1=['train.experimental.MixedPrecisionLossScaleOptimizer']) +@deprecation.deprecated_endpoints( + 'train.experimental.MixedPrecisionLossScaleOptimizer') +@tf_export(v1=['mixed_precision.MixedPrecisionLossScaleOptimizer', + 'train.experimental.MixedPrecisionLossScaleOptimizer']) class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer): """An optimizer that applies loss scaling. diff --git a/tensorflow/python/training/experimental/mixed_precision.py b/tensorflow/python/training/experimental/mixed_precision.py index af0e27dd860..d1a423a5cc3 100644 --- a/tensorflow/python/training/experimental/mixed_precision.py +++ b/tensorflow/python/training/experimental/mixed_precision.py @@ -23,6 +23,7 @@ from tensorflow.python.platform import tf_logging from tensorflow.python.training import optimizer from tensorflow.python.training.experimental import loss_scale_optimizer as loss_scale_optimizer_v1 from tensorflow.python.training.experimental import mixed_precision_global_state +from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @@ -61,6 +62,12 @@ def _wrap_optimizer(opt, loss_scale, use_v1_behavior): 'tf.keras.optimizers.Optimizer, but got: %s' % opt) +@deprecation.deprecated( + '2020-11-30', + 'Use tf.keras.mixed_precision. There is a guide at ' + 'https://www.tensorflow.org/guide/mixed_precision. Alternatively, ' + '`tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite` can ' + 'be used, but this is not recommended for TF2 code.') @tf_export('train.experimental.enable_mixed_precision_graph_rewrite', v1=[]) def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'): """Enable mixed precision via a graph rewrite. @@ -206,7 +213,10 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'): use_v1_behavior=False) -@tf_export(v1=['train.experimental.enable_mixed_precision_graph_rewrite']) +@deprecation.deprecated_endpoints( + 'train.experimental.enable_mixed_precision_graph_rewrite') +@tf_export(v1=['mixed_precision.enable_mixed_precision_graph_rewrite', + 'train.experimental.enable_mixed_precision_graph_rewrite']) def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'): """Enable mixed precision via a graph rewrite. @@ -348,6 +358,12 @@ def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale, return opt +@deprecation.deprecated( + '2020-11-30', + 'Use tf.keras.mixed_precision. There is a guide at ' + 'https://www.tensorflow.org/guide/mixed_precision. Alternatively, ' + '`tf.compat.v1.mixed_precision.disable_mixed_precision_graph_rewrite` can ' + 'be used, but this is not recommended for TF2 code.') @tf_export('train.experimental.disable_mixed_precision_graph_rewrite', v1=[]) def disable_mixed_precision_graph_rewrite(): """Disables the mixed precision graph rewrite. @@ -372,7 +388,10 @@ def disable_mixed_precision_graph_rewrite(): mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = False -@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite']) +@deprecation.deprecated_endpoints( + 'train.experimental.disable_mixed_precision_graph_rewrite') +@tf_export(v1=['mixed_precision.disable_mixed_precision_graph_rewrite', + 'train.experimental.disable_mixed_precision_graph_rewrite']) def disable_mixed_precision_graph_rewrite_v1(): """Disables the mixed precision graph rewrite. diff --git a/tensorflow/python/training/experimental/mixed_precision_global_state.py b/tensorflow/python/training/experimental/mixed_precision_global_state.py index 6df4fdbe593..4bf7916d983 100644 --- a/tensorflow/python/training/experimental/mixed_precision_global_state.py +++ b/tensorflow/python/training/experimental/mixed_precision_global_state.py @@ -33,7 +33,7 @@ mixed_precision_graph_rewrite_is_enabled = False # Session has already been created. non_mixed_precision_session_created = False -# Whether the global tf.keras.mixed_precision.experimental.Policy uses mixed -# precision. Used to raise an error message if both a mixed Policy and the graph -# rewrite are used at the same time. +# Whether the global tf.keras.mixed_precision.Policy uses mixed precision. Used +# to raise an error message if both a mixed Policy and the graph rewrite are +# used at the same time. using_mixed_precision_policy = False diff --git a/tensorflow/python/training/saving/BUILD b/tensorflow/python/training/saving/BUILD index 12940840309..7a65e3a77e8 100644 --- a/tensorflow/python/training/saving/BUILD +++ b/tensorflow/python/training/saving/BUILD @@ -39,6 +39,9 @@ cuda_py_test( srcs = [ "functional_saver_test.py", ], + tags = [ + "no_windows", # TODO(b/171350346) + ], deps = [ ":checkpoint_options", ":functional_saver", diff --git a/tensorflow/python/training/tracking/BUILD b/tensorflow/python/training/tracking/BUILD index 6001dc2cbbe..370b78c84f5 100644 --- a/tensorflow/python/training/tracking/BUILD +++ b/tensorflow/python/training/tracking/BUILD @@ -57,7 +57,6 @@ py_library( tf_py_test( name = "tracking_test", srcs = ["tracking_test.py"], - tfrt_enabled = True, deps = [ ":base", ":tracking", @@ -159,7 +158,6 @@ tf_py_test( name = "util_test", srcs = ["util_test.py"], tags = ["notsan"], # b/74395663 - tfrt_enabled = True, deps = [ ":base", ":graph_view", @@ -200,7 +198,6 @@ tf_py_test( tags = [ "notsan", # b/74395663 ], - tfrt_enabled = True, deps = [ ":tracking", ":util", @@ -243,7 +240,6 @@ tf_py_test( tf_py_test( name = "benchmarks_test", srcs = ["benchmarks_test.py"], - tfrt_enabled = True, deps = [ ":util", "//tensorflow/python:framework_ops", diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py index 5e822f87e8c..e634a2c67cf 100644 --- a/tensorflow/python/util/deprecation.py +++ b/tensorflow/python/util/deprecation.py @@ -29,6 +29,7 @@ from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_stack +from tensorflow.tools.docs import doc_controls # Allow deprecation warnings to be silenced temporarily with a context manager. @@ -305,8 +306,23 @@ def deprecated(date, instructions, warn_once=True): """ _validate_deprecation_args(date, instructions) - def deprecated_wrapper(func): + def deprecated_wrapper(func_or_class): """Deprecation wrapper.""" + if isinstance(func_or_class, type): + # If a class is deprecated, you actually want to wrap the constructor. + cls = func_or_class + if cls.__new__ is object.__new__: + func = cls.__init__ + constructor_name = '__init__' + else: + func = cls.__new__ + constructor_name = '__new__' + + else: + cls = None + constructor_name = None + func = func_or_class + decorator_utils.validate_callable(func, 'deprecated') @functools.wraps(func) def new_func(*args, **kwargs): # pylint: disable=missing-docstring @@ -322,10 +338,25 @@ def deprecated(date, instructions, warn_once=True): 'in a future version' if date is None else ('after %s' % date), instructions) return func(*args, **kwargs) - return tf_decorator.make_decorator( + + doc_controls.set_deprecated(new_func) + new_func = tf_decorator.make_decorator( func, new_func, 'deprecated', _add_deprecated_function_notice_to_docstring(func.__doc__, date, instructions)) + + if cls is None: + return new_func + else: + # Insert the wrapped function as the constructor + setattr(cls, constructor_name, new_func) + + # And update the docstring of the class. + cls.__doc__ = _add_deprecated_function_notice_to_docstring( + cls.__doc__, date, instructions) + + return cls + return deprecated_wrapper diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py index 20c0846cfb8..a8babf3b011 100644 --- a/tensorflow/python/util/deprecation_test.py +++ b/tensorflow/python/util/deprecation_test.py @@ -19,6 +19,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import enum + from tensorflow.python.framework import test_util from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging @@ -95,6 +98,72 @@ class DeprecationTest(test.TestCase): _fn() self.assertEqual(1, mock_warning.call_count) + @test.mock.patch.object(logging, "warning", autospec=True) + def test_deprecated_init_class(self, mock_warning): + date = "2016-07-04" + instructions = "This is how you update..." + + @deprecation.deprecated(date, instructions, warn_once=True) + class MyClass(): + """A test class.""" + + def __init__(self, a): + pass + + MyClass("") + self.assertEqual(1, mock_warning.call_count) + MyClass("") + self.assertEqual(1, mock_warning.call_count) + self.assertIn("IS DEPRECATED", MyClass.__doc__) + + @test.mock.patch.object(logging, "warning", autospec=True) + def test_deprecated_new_class(self, mock_warning): + date = "2016-07-04" + instructions = "This is how you update..." + + @deprecation.deprecated(date, instructions, warn_once=True) + class MyStr(str): + + def __new__(cls, value): + return str.__new__(cls, value) + + MyStr("abc") + self.assertEqual(1, mock_warning.call_count) + MyStr("abc") + self.assertEqual(1, mock_warning.call_count) + self.assertIn("IS DEPRECATED", MyStr.__doc__) + + @test.mock.patch.object(logging, "warning", autospec=True) + def test_deprecated_enum(self, mock_warning): + date = "2016-07-04" + instructions = "This is how you update..." + + @deprecation.deprecated(date, instructions, warn_once=True) + class MyEnum(enum.Enum): + a = 1 + b = 2 + + self.assertIs(MyEnum(1), MyEnum.a) + self.assertEqual(1, mock_warning.call_count) + self.assertIs(MyEnum(2), MyEnum.b) + self.assertEqual(1, mock_warning.call_count) + self.assertIn("IS DEPRECATED", MyEnum.__doc__) + + @test.mock.patch.object(logging, "warning", autospec=True) + def test_deprecated_namedtuple(self, mock_warning): + date = "2016-07-04" + instructions = "This is how you update..." + + mytuple = deprecation.deprecated( + date, instructions, warn_once=True)( + collections.namedtuple("my_tuple", ["field1", "field2"])) + + mytuple(1, 2) + self.assertEqual(1, mock_warning.call_count) + mytuple(3, 4) + self.assertEqual(1, mock_warning.call_count) + self.assertIn("IS DEPRECATED", mytuple.__doc__) + @test.mock.patch.object(logging, "warning", autospec=True) def test_silence(self, mock_warning): date = "2016-07-04" diff --git a/tensorflow/python/util/keras_deps.py b/tensorflow/python/util/keras_deps.py new file mode 100644 index 00000000000..3504d499769 --- /dev/null +++ b/tensorflow/python/util/keras_deps.py @@ -0,0 +1,44 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Interface that provides access to Keras dependencies. + +This library is a common interface that contains Keras functions needed by +TensorFlow and TensorFlow Lite and is required as per the dependency inversion +principle (https://en.wikipedia.org/wiki/Dependency_inversion_principle). As per +this principle, high-level modules (eg: TensorFlow and TensorFlow Lite) should +not depend on low-level modules (eg: Keras) and instead both should depend on a +common interface such as this file. +""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +_KERAS_CALL_CONTEXT_FUNCTION = None + + +def register_call_context_function(func): + global _KERAS_CALL_CONTEXT_FUNCTION + # TODO(scottzhu): Disable duplicated inject once keras is moved to + # third_party/py/keras. + _KERAS_CALL_CONTEXT_FUNCTION = func + + +def get_call_context_function(): + global _KERAS_CALL_CONTEXT_FUNCTION + return _KERAS_CALL_CONTEXT_FUNCTION diff --git a/tensorflow/python/util/object_identity.py b/tensorflow/python/util/object_identity.py index dfdd08d1501..b4f4c63bc96 100644 --- a/tensorflow/python/util/object_identity.py +++ b/tensorflow/python/util/object_identity.py @@ -30,7 +30,7 @@ class _ObjectIdentityWrapper(object): _ListWrapper objects to object-identity collections. """ - __slots__ = ["_wrapped"] + __slots__ = ["_wrapped", "__weakref__"] def __init__(self, wrapped): self._wrapped = wrapped @@ -72,6 +72,8 @@ class _ObjectIdentityWrapper(object): class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper): + __slots__ = () + def __init__(self, wrapped): super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped)) @@ -99,6 +101,8 @@ class Reference(_ObjectIdentityWrapper): ``` """ + __slots__ = () + # Disabling super class' unwrapped field. unwrapped = property() @@ -153,6 +157,8 @@ class ObjectIdentityDictionary(collections_abc.MutableMapping): class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary): """Like weakref.WeakKeyDictionary, but compares objects with "is".""" + __slots__ = ["__weakref__"] + def _wrap_key(self, key): return _WeakObjectIdentityWrapper(key) @@ -173,7 +179,7 @@ class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary): class ObjectIdentitySet(collections_abc.MutableSet): """Like the built-in set, but compares objects with "is".""" - __slots__ = ["_storage"] + __slots__ = ["_storage", "__weakref__"] def __init__(self, *args): self._storage = set(self._wrap_key(obj) for obj in list(*args)) @@ -221,6 +227,8 @@ class ObjectIdentitySet(collections_abc.MutableSet): class ObjectIdentityWeakSet(ObjectIdentitySet): """Like weakref.WeakSet, but compares objects with "is".""" + __slots__ = () + def _wrap_key(self, key): return _WeakObjectIdentityWrapper(key) diff --git a/tensorflow/security/README.md b/tensorflow/security/README.md index 9f02de4153d..27e24b6392b 100644 --- a/tensorflow/security/README.md +++ b/tensorflow/security/README.md @@ -10,6 +10,8 @@ in [SECURITY.md](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.m | Advisory Number | Type | Versions affected | Reported by | Additional Information | |-----------------|--------------------|:-----------------:|-----------------------|-----------------------------| +| [TFSA-2020-028](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-028.md) | Float cast overflow undefined behavior | <= 2.3 | (Reported on GitHub) | [issue report](https://github.com/tensorflow/tensorflow/issues/42129) | +| [TFSA-2020-027](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-027.md) | Segfault in `tf.quantization.quantize_and_dequantize`| <= 2.3 | (Reported on GitHub) | [issue report](https://github.com/tensorflow/tensorflow/issues/42105) | | [TFSA-2020-026](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-026.md) | Segfault in `tf.raw_ops.Switch` in eager mode | 2.2.0, 2.3.0 | Aivul Team from Qihoo 360 | | | [TFSA-2020-025](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-025.md) | Undefined behavior in `dlpack.to_dlpack` | 2.2.0, 2.3.0 | Aivul Team from Qihoo 360 | | | [TFSA-2020-024](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2020-024.md) | Memory leak in `dlpack.to_dlpack` | 2.2.0, 2.3.0 | Aivul Team from Qihoo 360 | | diff --git a/tensorflow/security/advisory/tfsa-2020-027.md b/tensorflow/security/advisory/tfsa-2020-027.md new file mode 100644 index 00000000000..e2c87c97953 --- /dev/null +++ b/tensorflow/security/advisory/tfsa-2020-027.md @@ -0,0 +1,52 @@ +## TFSA-2020-027: Segfault in `tf.quantization.quantize_and_dequantize` + +### CVE Number +CVE-2020-15265 + +### Impact +An attacker can pass an invalid `axis` value to +`tf.quantization.quantize_and_dequantize`: + +```python +tf.quantization.quantize_and_dequantize( + input=[2.5, 2.5], input_min=[0,0], input_max=[1,1], axis=10) +``` + +This results in accessing [a dimension outside the rank of the input +tensor](https://github.com/tensorflow/tensorflow/blob/0225022b725993bfc19b87a02a2faaad9a53bc17/tensorflow/core/kernels/quantize_and_dequantize_op.cc#L74) +in the C++ kernel implementation: +```cc +const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_); +``` + +However, [`dim_size` only does a +`DCHECK`](https://github.com/tensorflow/tensorflow/blob/0225022b725993bfc19b87a02a2faaad9a53bc17/tensorflow/core/framework/tensor_shape.cc#L292-L307) +to validate the argument and then uses it to access the corresponding element of +an array: +```cc +int64 TensorShapeBase::dim_size(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + DoStuffWith(dims_[d]); +} +``` + +Since in normal builds, `DCHECK`-like macros are no-ops, this results in +segfault and access out of bounds of the array. + +### Patches + +We have patched the issue in +[eccb7ec454e6617738554a255d77f08e60ee0808](https://github.com/tensorflow/tensorflow/commit/eccb7ec454e6617738554a255d77f08e60ee0808) +and will release TensorFlow 2.4.0 containing the patch. TensorFlow nightly +packages after this commit will also have the issue resolved. + +### For more information +Please consult [our security +guide](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for +more information regarding the security model and how to contact us with issues +and questions. + +### Attribution +This vulnerability has been reported in +[#42105](https://github.com/tensorflow/issues/42105). diff --git a/tensorflow/security/advisory/tfsa-2020-028.md b/tensorflow/security/advisory/tfsa-2020-028.md new file mode 100644 index 00000000000..a69df53a77e --- /dev/null +++ b/tensorflow/security/advisory/tfsa-2020-028.md @@ -0,0 +1,27 @@ +## TFSA-2020-028: Float cast overflow undefined behavior + +### CVE Number +CVE-2020-15266 + +### Impact +When the `boxes` argument of `tf.image.crop_and_resize` has a very large value, +the CPU kernel implementation receives it as a C++ `nan` floating point value. +Attempting to operate on this is undefined behavior which later produces a +segmentation fault. + +### Patches + +We have patched the issue in +[c0319231333f0f16e1cc75ec83660b01fedd4182](https://github.com/tensorflow/tensorflow/commit/c0319231333f0f16e1cc75ec83660b01fedd4182) +and will release TensorFlow 2.4.0 containing the patch. TensorFlow nightly +packages after this commit will also have the issue resolved. + +### For more information +Please consult [our security +guide](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for +more information regarding the security model and how to contact us with issues +and questions. + +### Attribution +This vulnerability has been reported in +[#42129](https://github.com/tensorflow/issues/42129). diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD index 84f592439ea..d6333d54c1e 100644 --- a/tensorflow/stream_executor/BUILD +++ b/tensorflow/stream_executor/BUILD @@ -61,7 +61,6 @@ cc_library( "blas.h", "device_description.h", "device_options.h", - "dnn.h", "event.cc", "fft.h", "kernel_cache_config.h", @@ -103,7 +102,6 @@ cc_library( cc_library( name = "kernel", srcs = [ - "dnn.h", "fft.h", "kernel.cc", "plugin.h", @@ -300,6 +298,7 @@ cc_library( name = "host_or_device_scalar", hdrs = ["host_or_device_scalar.h"], deps = [ + ":data_type", ":device_memory", "//tensorflow/stream_executor/platform", ], @@ -331,7 +330,6 @@ cc_library( ], hdrs = [ "blas.h", - "dnn.h", "executor_cache.h", "fft.h", "kernel.h", @@ -423,11 +421,21 @@ tf_proto_library( make_default_target_header_only = True, ) +cc_library( + name = "data_type", + hdrs = ["data_type.h"], + deps = [ + ":dnn_proto_cc", + "//tensorflow/stream_executor/platform", + ], +) + cc_library( name = "dnn", srcs = ["dnn.cc"], hdrs = ["dnn.h"], deps = [ + ":data_type", ":device_memory", ":dnn_proto_cc", ":stream_executor_headers", @@ -445,7 +453,6 @@ cc_library( cc_library( name = "stream_executor_internal", srcs = [ - "dnn.h", "stream_executor_internal.cc", ], hdrs = [ @@ -474,7 +481,6 @@ cc_library( name = "stream_executor_pimpl_header", hdrs = [ "device_description.h", - "dnn.h", "kernel.h", "kernel_cache_config.h", "stream_executor_pimpl.h", diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc index f499b3003d0..ca597595969 100644 --- a/tensorflow/stream_executor/blas.cc +++ b/tensorflow/stream_executor/blas.cc @@ -95,5 +95,30 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) { return os << ComputationTypeString(ty); } +string DataTypeString(DataType ty) { + switch (ty) { + case DataType::kHalf: + return "f16"; + case DataType::kFloat: + return "f32"; + case DataType::kDouble: + return "f64"; + case DataType::kInt8: + return "i8"; + case DataType::kInt32: + return "i32"; + case DataType::kComplexFloat: + return "complex f32"; + case DataType::kComplexDouble: + return "complex f64"; + default: + LOG(FATAL) << "Unknown DataType " << static_cast(ty); + } +} + +std::ostream& operator<<(std::ostream& os, DataType ty) { + return os << DataTypeString(ty); +} + } // namespace blas } // namespace stream_executor diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index 5018d487ed1..20776b8416d 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -43,7 +43,7 @@ limitations under the License. #include #include -#include "tensorflow/stream_executor/host_or_device_scalar.h" +#include "tensorflow/stream_executor/dnn.h" // For DataType, ToDataType #include "tensorflow/stream_executor/lib/array_slice.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform/port.h" @@ -60,6 +60,9 @@ class ScratchAllocator; template class DeviceMemory; +template +class HostOrDeviceScalar; + namespace blas { // Specifies whether the input matrix will be transposed or @@ -101,6 +104,18 @@ enum class ComputationType { kI32, // 32-bit integer kComplexF32, // Complex number comprised of two f32s. kComplexF64, // Complex number comprised of two f64s. + // The below values are only supported for BlasLt routines (both real and + // complex). They use float32 for accumulation but round the input mantissas + // to a smaller number of bits. + kTF32AsF32, // 32-bit floating-point with reduced (>=10-bit) mantissa + kBF16AsF32, // 32-bit floating-point with reduced (7-bit) mantissa +}; + +enum class Epilogue { + kDefault = 1, // No special postprocessing + kReLU = 2, // Apply ReLU func point-wise to the results + kBias = 4, // Add broadcasted bias vector to the results + kBiasThenReLU = kBias | kReLU, // Apply bias and then ReLU transform }; // Converts a ComputationType to a string. @@ -108,6 +123,21 @@ std::string ComputationTypeString(ComputationType ty); std::ostream &operator<<(std::ostream &os, ComputationType ty); +using dnn::DataType; +using dnn::ToDataType; + +// Describes the type of pointers for the scaling factors alpha and beta in +// blaslt routines. +enum class PointerMode { + kHost, + kDevice, +}; + +// Converts a ComputationType to a string. +string DataTypeString(DataType ty); + +std::ostream &operator<<(std::ostream &os, DataType ty); + // Opaque identifier for an "algorithm" used by a blas routine. This functions // as a hint to the blas library. typedef int64 AlgorithmType; @@ -163,6 +193,44 @@ class AlgorithmConfig { AlgorithmType algorithm_; }; +struct IBlasLtMatmulPlan { + // Returns the data type of the A and B (input) matrices. + virtual DataType ab_type() const = 0; + // Returns the data type of the C (input/output) matrix. + virtual DataType c_type() const = 0; + virtual ~IBlasLtMatmulPlan() {} +}; + +struct IBlasLtMatmulAlgorithm { + virtual ~IBlasLtMatmulAlgorithm() {} + // Returns the index of the algorithm within the list returned by + // GetBlasLtMatmulAlgorithms. + virtual AlgorithmType index() const = 0; + // Returns the workspace size required by the algorithm in bytes. + virtual size_t workspace_size() const = 0; +}; + +// Parameters for the CreateBlasLtMatmulPlan method. +struct BlasLtMatmulPlanParams { + DataType ab_type; + DataType c_type; + ComputationType computation_type; + PointerMode pointer_mode; + Epilogue epilogue; + Transpose transa; + Transpose transb; + uint64 m; + uint64 n; + uint64 k; + int64 lda; + int64 ldb; + int64 ldc; + int batch_count = 1; + int64 stride_a = 0; + int64 stride_b = 0; + int64 stride_c = 0; +}; + // BLAS support interface -- this can be derived from a GPU executor when the // underlying platform has an BLAS library implementation available. See // StreamExecutor::AsBlas(). @@ -1383,6 +1451,71 @@ class BlasSupport { const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) = 0; + // Creates a backend-specific plan object for a blaslt matmul operation, which + // can then be passed to DoBlasLtMatmul(). When possible, plans should be + // created once and reused for multiple calls to DoBlasLtMatmul(). + virtual port::StatusOr> + CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams ¶ms) = 0; + + // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are + // returned in the order of increasing estimated compute time according to an + // internal heuristic. The first returned algorithm can be used as the default + // algorithm if no autotuning is to be performed. + virtual port::StatusOr< + std::vector>> + GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, + size_t max_workspace_size, + int max_algorithm_count) = 0; + + // Executes a blaslt matmul operation on the stream. If output_profile_result + // is not nullptr, the operation is profiled, error messages are + // suppressed, and output_profile_result->algorithm() is set to + // algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when + // creating the plan, the bias argument here must refer to a valid device + // vector of length equal to the number of rows in matrix c. If epilogue was + // set to any other value then the bias argument here must be null. The bias + // vector is broadcast across the batch dimension. + // Note that the data types of a and b (c and bias) must match the ab_type + // (c_type) with which the plan was created, and the data types of alpha and + // beta must match the data type of c. + virtual bool DoBlasLtMatmul( + Stream *stream, const blas::IBlasLtMatmulPlan *plan, + const HostOrDeviceScalar &alpha, DeviceMemoryBase a, + DeviceMemoryBase b, const HostOrDeviceScalar &beta, + DeviceMemoryBase c, ScratchAllocator *scratch_allocator, + const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, + blas::ProfileResult *output_profile_result) = 0; + + template + bool DoBlasLtMatmul(Stream *stream, const blas::IBlasLtMatmulPlan *plan, + const HostOrDeviceScalar &alpha, + const DeviceMemory &a, + const DeviceMemory &b, + const HostOrDeviceScalar &beta, + DeviceMemory *c, + ScratchAllocator *scratch_allocator, + const blas::IBlasLtMatmulAlgorithm *algorithm, + const DeviceMemory &bias = {}, + blas::ProfileResult *output_profile_result = nullptr) { + constexpr blas::DataType ab_type = blas::ToDataType::value; + if (ab_type != plan->ab_type()) { + VLOG(2) << "DoBlasLtMatmul returning false because a and b type does " + "not match plan: expected " + << plan->ab_type() << ", got " << ab_type; + return false; + } + constexpr blas::DataType c_type = blas::ToDataType::value; + if (c_type != plan->c_type()) { + VLOG(2) << "DoBlasLtMatmul returning false because c type does " + "not match plan: expected " + << plan->c_type() << ", got " << c_type; + return false; + } + return DoBlasLtMatmul(stream, plan, alpha, a, b, beta, *c, + scratch_allocator, algorithm, bias, + output_profile_result); + } + virtual port::Status GetVersion(std::string *version) = 0; protected: @@ -2196,6 +2329,19 @@ class BlasSupport { uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ DeviceMemory> *b, int ldb) override; \ + port::StatusOr> \ + CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams ¶ms) override; \ + port::StatusOr>> \ + GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, \ + size_t max_workspace_size, \ + int max_algorithm_count) override; \ + bool DoBlasLtMatmul( \ + Stream *stream, const blas::IBlasLtMatmulPlan *plan, \ + const HostOrDeviceScalar &alpha, DeviceMemoryBase a, \ + DeviceMemoryBase b, const HostOrDeviceScalar &beta, \ + DeviceMemoryBase c, ScratchAllocator *scratch_allocator, \ + const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, \ + blas::ProfileResult *output_profile_result) override; \ port::Status GetVersion(std::string *version) override; } // namespace blas diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD index caab6223fda..0ee227d51f2 100644 --- a/tensorflow/stream_executor/cuda/BUILD +++ b/tensorflow/stream_executor/cuda/BUILD @@ -142,7 +142,9 @@ cc_library( tf_cuda_cc_test( name = "cuda_driver_test", srcs = ["cuda_driver_test.cc"], - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + [ + "no_cuda_asan", # TODO(b/171512140): re-enable. + ], deps = [ "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -154,7 +156,9 @@ tf_cuda_cc_test( tf_cuda_cc_test( name = "memcpy_test", srcs = ["memcpy_test.cc"], - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + [ + "no_cuda_asan", # TODO(b/171512140): re-enable. + ], deps = [ "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -251,6 +255,31 @@ alias( visibility = ["//visibility:public"], ) +cc_library( + name = "cublas_lt_stub", + srcs = if_cuda_is_configured(["cublasLt_stub.cc"]), + textual_hdrs = glob(["cublasLt_*.inc"]), + deps = if_cuda_is_configured([ + # LINT.IfChange + "@local_config_cuda//cuda:cublas_headers", + # LINT.ThenChange(//tensorflow/copy.bara.sky:cublasLt_headers) + "@local_config_cuda//cuda:cuda_headers", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/platform:dso_loader", + ]), +) + +cc_library(name = "empty_lib") + +alias( + name = "cublas_lt_lib", + actual = select({ + "//tensorflow:oss": ":cublas_lt_stub", + "//conditions:default": ":empty_lib", + }), + visibility = ["//visibility:public"], +) + cc_library( name = "cublas_plugin", srcs = if_cuda_is_configured(["cuda_blas.cc"]), @@ -258,6 +287,7 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":cublas_lib", + ":cublas_lt_lib", ":cuda_activation", ":cuda_gpu_executor", ":cuda_platform_id", @@ -632,7 +662,9 @@ cc_library( tf_cuda_cc_test( name = "redzone_allocator_test", srcs = ["redzone_allocator_test.cc"], - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + [ + "no_cuda_asan", # TODO(b/171512140): re-enable. + ], deps = [ ":cuda_activation", ":cuda_gpu_executor", diff --git a/tensorflow/stream_executor/cuda/cublasLt_11_0.inc b/tensorflow/stream_executor/cuda/cublasLt_11_0.inc new file mode 100644 index 00000000000..5645753c56b --- /dev/null +++ b/tensorflow/stream_executor/cuda/cublasLt_11_0.inc @@ -0,0 +1,390 @@ +// Auto-generated, do not edit. + +extern "C" { + +cublasStatus_t CUBLASWINAPI cublasLtCreate(cublasLtHandle_t *lightHandle) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtHandle_t *); + static auto func_ptr = LoadSymbol("cublasLtCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle); +} + +cublasStatus_t CUBLASWINAPI cublasLtDestroy(cublasLtHandle_t lightHandle) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtHandle_t); + static auto func_ptr = LoadSymbol("cublasLtDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle); +} + +size_t CUBLASWINAPI cublasLtGetVersion(void) { + using FuncPtr = size_t(CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol("cublasLtGetVersion"); + if (!func_ptr) return 0; + return func_ptr(); +} + +size_t CUBLASWINAPI cublasLtGetCudartVersion(void) { + using FuncPtr = size_t(CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol("cublasLtGetCudartVersion"); + if (!func_ptr) return 0; + return func_ptr(); +} + +cublasStatus_t CUBLASWINAPI cublasLtGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(libraryPropertyType, int *); + static auto func_ptr = LoadSymbol("cublasLtGetProperty"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(type, value); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmul( + cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, + const void *alpha, /* host or device pointer */ + const void *A, cublasLtMatrixLayout_t Adesc, const void *B, + cublasLtMatrixLayout_t Bdesc, const void *beta, /* host or device pointer */ + const void *C, cublasLtMatrixLayout_t Cdesc, void *D, + cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t *algo, + void *workspace, size_t workspaceSizeInBytes, cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasLtHandle_t, cublasLtMatmulDesc_t, const void *, const void *, + cublasLtMatrixLayout_t, const void *, cublasLtMatrixLayout_t, + const void *, const void *, cublasLtMatrixLayout_t, void *, + cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *, void *, size_t, + cudaStream_t); + static auto func_ptr = LoadSymbol("cublasLtMatmul"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, + Cdesc, D, Ddesc, algo, workspace, workspaceSizeInBytes, + stream); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransform( + cublasLtHandle_t lightHandle, cublasLtMatrixTransformDesc_t transformDesc, + const void *alpha, /* host or device pointer */ + const void *A, cublasLtMatrixLayout_t Adesc, + const void *beta, /* host or device pointer */ + const void *B, cublasLtMatrixLayout_t Bdesc, void *C, + cublasLtMatrixLayout_t Cdesc, cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasLtHandle_t, cublasLtMatrixTransformDesc_t, const void *, + const void *, cublasLtMatrixLayout_t, const void *, const void *, + cublasLtMatrixLayout_t, void *, cublasLtMatrixLayout_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cublasLtMatrixTransform"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, transformDesc, alpha, A, Adesc, beta, B, Bdesc, + C, Cdesc, stream); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutInit_internal( // + cublasLtMatrixLayout_t matLayout, size_t size, cudaDataType type, + uint64_t rows, uint64_t cols, int64_t ld) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatrixLayout_t, size_t, cudaDataType, uint64_t, uint64_t, + int64_t); + static auto func_ptr = + LoadSymbol("cublasLtMatrixLayoutInit_internal"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matLayout, size, type, rows, cols, ld); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutCreate( // + cublasLtMatrixLayout_t *matLayout, cudaDataType type, uint64_t rows, + uint64_t cols, int64_t ld) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatrixLayout_t *, cudaDataType, uint64_t, uint64_t, int64_t); + static auto func_ptr = LoadSymbol("cublasLtMatrixLayoutCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matLayout, type, rows, cols, ld); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatrixLayout_t); + static auto func_ptr = LoadSymbol("cublasLtMatrixLayoutDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matLayout); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutSetAttribute( // + cublasLtMatrixLayout_t matLayout, cublasLtMatrixLayoutAttribute_t attr, + const void *buf, size_t sizeInBytes) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, const void *, + size_t); + static auto func_ptr = + LoadSymbol("cublasLtMatrixLayoutSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matLayout, attr, buf, sizeInBytes); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutGetAttribute( // + cublasLtMatrixLayout_t matLayout, cublasLtMatrixLayoutAttribute_t attr, + void *buf, size_t sizeInBytes, size_t *sizeWritten) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, void *, size_t, + size_t *); + static auto func_ptr = + LoadSymbol("cublasLtMatrixLayoutGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matLayout, attr, buf, sizeInBytes, sizeWritten); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal( // + cublasLtMatmulDesc_t matmulDesc, size_t size, + cublasComputeType_t computeType, cudaDataType_t scaleType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatmulDesc_t, size_t, cublasComputeType_t, cudaDataType_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulDescInit_internal"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matmulDesc, size, computeType, scaleType); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulDescCreate( + cublasLtMatmulDesc_t *matmulDesc, cublasComputeType_t computeType, + cudaDataType_t scaleType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasLtMatmulDesc_t *, cublasComputeType_t, cudaDataType_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulDescCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matmulDesc, computeType, scaleType); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulDesc_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulDescDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matmulDesc); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulDescSetAttribute( // + cublasLtMatmulDesc_t matmulDesc, cublasLtMatmulDescAttributes_t attr, + const void *buf, size_t sizeInBytes) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, const void *, + size_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulDescSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matmulDesc, attr, buf, sizeInBytes); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulDescGetAttribute( // + cublasLtMatmulDesc_t matmulDesc, cublasLtMatmulDescAttributes_t attr, + void *buf, size_t sizeInBytes, size_t *sizeWritten) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, void *, size_t, + size_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulDescGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matmulDesc, attr, buf, sizeInBytes, sizeWritten); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescInit_internal( + cublasLtMatrixTransformDesc_t transformDesc, size_t size, + cudaDataType scaleType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t, + size_t, cudaDataType); + static auto func_ptr = + LoadSymbol("cublasLtMatrixTransformDescInit_internal"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(transformDesc, size, scaleType); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescCreate( + cublasLtMatrixTransformDesc_t *transformDesc, cudaDataType scaleType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasLtMatrixTransformDesc_t *, cudaDataType); + static auto func_ptr = + LoadSymbol("cublasLtMatrixTransformDescCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(transformDesc, scaleType); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescDestroy( + cublasLtMatrixTransformDesc_t transformDesc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t); + static auto func_ptr = + LoadSymbol("cublasLtMatrixTransformDescDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(transformDesc); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescSetAttribute( // + cublasLtMatrixTransformDesc_t transformDesc, + cublasLtMatrixTransformDescAttributes_t attr, const void *buf, + size_t sizeInBytes) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t, + const void *, size_t); + static auto func_ptr = + LoadSymbol("cublasLtMatrixTransformDescSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(transformDesc, attr, buf, sizeInBytes); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescGetAttribute( // + cublasLtMatrixTransformDesc_t transformDesc, + cublasLtMatrixTransformDescAttributes_t attr, void *buf, size_t sizeInBytes, + size_t *sizeWritten) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t, + void *, size_t, size_t *); + static auto func_ptr = + LoadSymbol("cublasLtMatrixTransformDescGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(transformDesc, attr, buf, sizeInBytes, sizeWritten); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceInit_internal( + cublasLtMatmulPreference_t pref, size_t size) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulPreference_t, size_t); + static auto func_ptr = + LoadSymbol("cublasLtMatmulPreferenceInit_internal"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pref, size); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t *pref) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulPreference_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulPreferenceCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pref); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulPreference_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulPreferenceDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pref); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceSetAttribute( // + cublasLtMatmulPreference_t pref, cublasLtMatmulPreferenceAttributes_t attr, + const void *buf, size_t sizeInBytes) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, + const void *, size_t); + static auto func_ptr = + LoadSymbol("cublasLtMatmulPreferenceSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pref, attr, buf, sizeInBytes); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceGetAttribute( // + cublasLtMatmulPreference_t pref, cublasLtMatmulPreferenceAttributes_t attr, + void *buf, size_t sizeInBytes, size_t *sizeWritten) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, void *, + size_t, size_t *); + static auto func_ptr = + LoadSymbol("cublasLtMatmulPreferenceGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pref, attr, buf, sizeInBytes, sizeWritten); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetHeuristic( + cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t operationDesc, + cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc, + cublasLtMatmulPreference_t preference, int requestedAlgoCount, + cublasLtMatmulHeuristicResult_t heuristicResultsArray[], + int *returnAlgoCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t, + cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, + cublasLtMatmulPreference_t, int, cublasLtMatmulHeuristicResult_t[], + int *); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoGetHeuristic"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, + preference, requestedAlgoCount, heuristicResultsArray, + returnAlgoCount); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetIds( + cublasLtHandle_t lightHandle, cublasComputeType_t computeType, + cudaDataType_t scaleType, cudaDataType_t Atype, cudaDataType_t Btype, + cudaDataType_t Ctype, cudaDataType_t Dtype, int requestedAlgoCount, + int algoIdsArray[], int *returnAlgoCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t, + cudaDataType_t, cudaDataType_t, cudaDataType_t, int, int[], int *); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoGetIds"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype, + Dtype, requestedAlgoCount, algoIdsArray, returnAlgoCount); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoInit( + cublasLtHandle_t lightHandle, cublasComputeType_t computeType, + cudaDataType_t scaleType, cudaDataType_t Atype, cudaDataType_t Btype, + cudaDataType_t Ctype, cudaDataType_t Dtype, int algoId, + cublasLtMatmulAlgo_t *algo) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t, + cudaDataType_t, cudaDataType_t, cudaDataType_t, int, + cublasLtMatmulAlgo_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoInit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype, + Dtype, algoId, algo); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCheck( // + cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t operationDesc, + cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc, + const cublasLtMatmulAlgo_t *algo, ///< may point to result->algo + cublasLtMatmulHeuristicResult_t *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( // + cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t, + cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, + const cublasLtMatmulAlgo_t *, ///< may point to result->algo + cublasLtMatmulHeuristicResult_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoCheck"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, algo, + result); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCapGetAttribute( + const cublasLtMatmulAlgo_t *algo, cublasLtMatmulAlgoCapAttributes_t attr, + void *buf, size_t sizeInBytes, size_t *sizeWritten) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoCapAttributes_t, void *, + size_t, size_t *); + static auto func_ptr = + LoadSymbol("cublasLtMatmulAlgoCapGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigSetAttribute( + cublasLtMatmulAlgo_t *algo, cublasLtMatmulAlgoConfigAttributes_t attr, + const void *buf, size_t sizeInBytes) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t, + const void *, size_t); + static auto func_ptr = + LoadSymbol("cublasLtMatmulAlgoConfigSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(algo, attr, buf, sizeInBytes); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigGetAttribute( + const cublasLtMatmulAlgo_t *algo, cublasLtMatmulAlgoConfigAttributes_t attr, + void *buf, size_t sizeInBytes, size_t *sizeWritten) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t, + void *, size_t, size_t *); + static auto func_ptr = + LoadSymbol("cublasLtMatmulAlgoConfigGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cublasLt_stub.cc b/tensorflow/stream_executor/cuda/cublasLt_stub.cc new file mode 100644 index 00000000000..aae8a94285b --- /dev/null +++ b/tensorflow/stream_executor/cuda/cublasLt_stub.cc @@ -0,0 +1,59 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "third_party/gpus/cuda/include/cublasLt.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "tensorflow/stream_executor/lib/env.h" +#include "tensorflow/stream_executor/platform/dso_loader.h" + +// Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO. + +namespace { +// Returns DSO handle or null if loading the DSO fails. +void* GetDsoHandle() { +#ifdef PLATFORM_GOOGLE + return nullptr; +#else + static auto handle = []() -> void* { + auto handle_or = + stream_executor::internal::DsoLoader::GetCublasLtDsoHandle(); + if (!handle_or.ok()) return nullptr; + return handle_or.ValueOrDie(); + }(); + return handle; +#endif +} + +template +T LoadSymbol(const char* symbol_name) { + void* symbol = nullptr; + if (auto handle = GetDsoHandle()) { + stream_executor::port::Env::Default() + ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + .IgnoreError(); + } + return reinterpret_cast(symbol); +} + +void LogFatalSymbolNotFound(const char* symbol_name) { + LOG(FATAL) << symbol_name << " symbol not found."; +} + +cublasStatus_t GetSymbolNotFoundError() { return CUBLAS_STATUS_INTERNAL_ERROR; } +} // namespace + +// We only use cublasLt from CUDA 11.0 onward. +#if CUDA_VERSION >= 11000 +#include "tensorflow/stream_executor/cuda/cublasLt_11_0.inc" +#endif diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 1e2fd0ef6b4..7fb94c7f543 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" @@ -226,17 +227,38 @@ bool CUDABlas::Init() { return false; } +#if CUDA_VERSION >= 11000 + ret = cublasLtCreate(&blasLt_); + if (ret != CUBLAS_STATUS_SUCCESS) { + LOG(ERROR) << "failed to create cublasLt handle: " << ToString(ret); + return false; + } +#endif // CUDA_VERSION >= 11000 + return true; } CUDABlas::CUDABlas(gpu::GpuExecutor *parent) - : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {} + : parent_(CHECK_NOTNULL(parent)), + blas_(nullptr) +#if CUDA_VERSION >= 11000 + , + blasLt_(nullptr) +#endif +{ +} CUDABlas::~CUDABlas() { if (blas_ != nullptr) { gpu::ScopedActivateExecutorContext sac{parent_}; cublasDestroy(blas_); } +#if CUDA_VERSION >= 11000 + if (blasLt_ != nullptr) { + gpu::ScopedActivateExecutorContext sac{parent_}; + cublasLtDestroy(blasLt_); + } +#endif } bool CUDABlas::SetStream(Stream *stream) { @@ -253,6 +275,13 @@ bool CUDABlas::SetStream(Stream *stream) { return true; } +cudaStream_t CUDABlas::CUDAStream(Stream *stream) { + CHECK(stream != nullptr); + CHECK(AsGpuStreamValue(stream) != nullptr); + gpu::ScopedActivateExecutorContext sac{parent_}; + return AsGpuStreamValue(stream); +} + namespace { // Helper functions transforming blas arguments into cuBLAS arguments. @@ -381,8 +410,122 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) { return CUDA_C_32F; case blas::ComputationType::kComplexF64: return CUDA_C_64F; + case blas::ComputationType::kTF32AsF32: // fall-through + case blas::ComputationType::kBF16AsF32: + // These cases are currently only supported in the blasLt routines, which + // use CUBLASComputationType() instead. + LOG(FATAL) << "Invalid value of blas::ComputationType."; } } + +#if CUDA_VERSION >= 11000 +cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) { + switch (ty) { + case blas::ComputationType::kF16: + return CUBLAS_COMPUTE_16F; + case blas::ComputationType::kF32: // fall-through + case blas::ComputationType::kComplexF32: + return CUBLAS_COMPUTE_32F; + case blas::ComputationType::kF64: // fall-through + case blas::ComputationType::kComplexF64: + return CUBLAS_COMPUTE_64F; + case blas::ComputationType::kI32: + return CUBLAS_COMPUTE_32I; + case blas::ComputationType::kTF32AsF32: + return CUBLAS_COMPUTE_32F_FAST_TF32; + case blas::ComputationType::kBF16AsF32: + return CUBLAS_COMPUTE_32F_FAST_16BF; + } +} +#endif // CUDA_VERSION >= 11000 + +blas::DataType GetScaleType(blas::DataType data_type, + blas::ComputationType compute_type) { + bool is_complex = data_type == blas::DataType::kComplexFloat || + data_type == blas::DataType::kComplexDouble; + switch (compute_type) { + case blas::ComputationType::kF16: + return blas::DataType::kHalf; + case blas::ComputationType::kF32: // fall-through + case blas::ComputationType::kComplexF32: // fall-through + case blas::ComputationType::kTF32AsF32: // fall-through + case blas::ComputationType::kBF16AsF32: + return is_complex ? blas::DataType::kComplexFloat + : blas::DataType::kFloat; + case blas::ComputationType::kF64: // fall-through + case blas::ComputationType::kComplexF64: + return is_complex ? blas::DataType::kComplexDouble + : blas::DataType::kDouble; + case blas::ComputationType::kI32: + return blas::DataType::kInt32; + } +} + +#if CUDA_VERSION >= 11000 +cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) { + switch (pointer_mode) { + case blas::PointerMode::kHost: + return CUBLASLT_POINTER_MODE_HOST; + case blas::PointerMode::kDevice: + return CUBLASLT_POINTER_MODE_DEVICE; + } +} +cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) { + switch (epilogue) { + case blas::Epilogue::kDefault: + return CUBLASLT_EPILOGUE_DEFAULT; + case blas::Epilogue::kReLU: + return CUBLASLT_EPILOGUE_RELU; + case blas::Epilogue::kBias: + return CUBLASLT_EPILOGUE_BIAS; + case blas::Epilogue::kBiasThenReLU: + return CUBLASLT_EPILOGUE_RELU_BIAS; + } +} +#endif // CUDA_VERSION >= 11000 + +cudaDataType_t GetCUDADataType(blas::DataType ty) { + switch (ty) { + case blas::DataType::kHalf: + return CUDA_R_16F; + case blas::DataType::kFloat: + return CUDA_R_32F; + case blas::DataType::kDouble: + return CUDA_R_64F; + case blas::DataType::kInt8: + return CUDA_R_8I; + case blas::DataType::kInt32: + return CUDA_R_32I; + case blas::DataType::kComplexFloat: + return CUDA_C_32F; + case blas::DataType::kComplexDouble: + return CUDA_C_64F; + default: + LOG(FATAL) << "Invalid value of blas::DataType in GetCUDADataType"; + } +} + +int GetDataTypeSizeBytes(blas::DataType ty) { + switch (ty) { + case blas::DataType::kHalf: + return 2; + case blas::DataType::kFloat: + return 4; + case blas::DataType::kDouble: + return 8; + case blas::DataType::kInt8: + return 1; + case blas::DataType::kInt32: + return 4; + case blas::DataType::kComplexFloat: + return 8; + case blas::DataType::kComplexDouble: + return 16; + default: + LOG(FATAL) << "Invalid value of blas::DataType in GetDataTypeSizeBytes"; + } +} + } // namespace template @@ -2921,6 +3064,680 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, GpuComplex(GpuMemoryMutable(b)), ldb); } +// We only use cublasLt from CUDA 11.0 onward. +#if CUDA_VERSION >= 11000 + +namespace { + +template +inline port::Status SetCublasLtAttr(cublasLtMatrixLayout_t handle, + cublasLtMatrixLayoutAttribute_t attr, + const T &value) { + cublasStatus_t status = + cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T)); + if (status != CUBLAS_STATUS_SUCCESS) { + return port::Status( + port::error::INTERNAL, + absl::StrCat("cublasLtMatrixLayoutSetAttribute(attr=", attr, + ", value=", value, ") failed: ", ToString(status))); + } + return port::Status::OK(); +} + +template +inline port::Status SetCublasLtAttr(cublasLtMatmulAlgo_t *handle, + cublasLtMatmulAlgoConfigAttributes_t attr, + const T &value) { + cublasStatus_t status = + cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T)); + if (status != CUBLAS_STATUS_SUCCESS) { + return port::Status( + port::error::INTERNAL, + absl::StrCat("cublasLtMatmulAlgoConfigSetAttribute(attr=", attr, + ", value=", value, ") failed: ", ToString(status))); + } + return port::Status::OK(); +} + +template +inline port::Status SetCublasLtAttr(cublasLtMatmulPreference_t handle, + cublasLtMatmulPreferenceAttributes_t attr, + const T &value) { + cublasStatus_t status = + cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value)); + if (status != CUBLAS_STATUS_SUCCESS) { + return port::Status( + port::error::INTERNAL, + absl::StrCat("cublasLtMatmulPreferenceSetAttribute(attr=", attr, + ", value=", value, ") failed: ", ToString(status))); + } + return port::Status::OK(); +} + +template +inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t *handle, + cublasLtMatmulAlgoConfigAttributes_t attr, + T *value) { + auto mutable_handle = const_cast(handle); + size_t bytes_written = 0; + return cublasLtMatmulAlgoConfigGetAttribute(mutable_handle, attr, value, + sizeof(T), &bytes_written) == + CUBLAS_STATUS_SUCCESS && + bytes_written == sizeof(T); +} + +template +inline const T &ValueForStrCat(const T &value) { + return value; +} +template +inline absl::Hex ValueForStrCat(T *ptr) { + return absl::Hex(reinterpret_cast(ptr)); +} + +template +inline port::Status SetCublasLtAttr(cublasLtMatmulDesc_t handle, + cublasLtMatmulDescAttributes_t attr, + const T &value) { + cublasStatus_t status = + cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value)); + if (status != CUBLAS_STATUS_SUCCESS) { + return port::Status( + port::error::INTERNAL, + absl::StrCat("cublasLtMatmulDescSetAttribute(attr=", attr, ", value=", + ValueForStrCat(value), ") failed: ", ToString(status))); + } + return port::Status::OK(); +} + +struct MatmulDescDestroyer { + void operator()(cublasLtMatmulDesc_t matmul_desc) const { + cublasLtMatmulDescDestroy(matmul_desc); + } +}; +struct LayoutDestroyer { + void operator()(cublasLtMatrixLayout_t layout) const { + cublasLtMatrixLayoutDestroy(layout); + } +}; +struct MatmulPreferenceDestroyer { + void operator()(cublasLtMatmulPreference_t matmul_pref) const { + cublasLtMatmulPreferenceDestroy(matmul_pref); + } +}; +using UniqueOpDesc = + std::unique_ptr::type, + MatmulDescDestroyer>; +using UniqueLayoutDesc = + std::unique_ptr::type, + LayoutDestroyer>; +using UniqueMatmulPreference = + std::unique_ptr::type, + MatmulPreferenceDestroyer>; + +port::StatusOr CreateCublasLtOperationDesc( + blas::ComputationType computation_type, blas::DataType scale_type, + blas::PointerMode pointer_mode, blas::Epilogue epilogue, + blas::Transpose transa, blas::Transpose transb) { + cublasLtMatmulDesc_t desc; + cublasComputeType_t cublas_compute_type = + CUBLASComputationType(computation_type); + cudaDataType_t cuda_scale_type = GetCUDADataType(scale_type); + cublasStatus_t status = + cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type); + if (status != CUBLAS_STATUS_SUCCESS) { + return port::Status( + port::error::INTERNAL, + absl::StrCat("cublasLtMatmulDescCreate(computation_type=", + computation_type, ") failed: ", ToString(status))); + } + UniqueOpDesc unique_desc(desc); + SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + CUBLASPointerMode(pointer_mode))); + SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE, + CUBLASEpilogue(epilogue))); + SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA, + CUDABlasTranspose(transa))); + SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB, + CUDABlasTranspose(transb))); + return unique_desc; +} + +port::StatusOr CreateCublasLtLayoutDesc( + blas::DataType data_type, uint64 rows, uint64 cols, int64 ld, int64 stride, + int batch_count) { + cublasLtMatrixLayout_t desc; + cublasStatus_t status = cublasLtMatrixLayoutCreate( + &desc, GetCUDADataType(data_type), rows, cols, ld); + if (status != CUBLAS_STATUS_SUCCESS) { + return port::Status( + port::error::INTERNAL, + absl::StrCat("cublasLtMatrixLayoutCreate failed: ", ToString(status))); + } + UniqueLayoutDesc unique_desc(desc); + SE_RETURN_IF_ERROR( + SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count)); + SE_RETURN_IF_ERROR(SetCublasLtAttr( + desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride)); + return unique_desc; +} + +// Helper function to allocate workspace. +port::Status AllocateWorkspace(void **workspace, + ScratchAllocator *scratch_allocator, + size_t num_bytes) { + SE_ASSIGN_OR_RETURN(DeviceMemory workspace_bytes, + scratch_allocator->AllocateBytes(num_bytes)); + *workspace = (void *)GpuMemoryMutable(&workspace_bytes); + return port::Status::OK(); +} + +template +blas::ComputationType ToComputationType(); +template <> +blas::ComputationType ToComputationType() { + return blas::ComputationType::kF16; +} +template <> +blas::ComputationType ToComputationType() { + return blas::ComputationType::kF32; +} +template <> +blas::ComputationType ToComputationType() { + return blas::ComputationType::kF64; +} +template <> +blas::ComputationType ToComputationType>() { + return blas::ComputationType::kComplexF32; +} +template <> +blas::ComputationType ToComputationType>() { + return blas::ComputationType::kComplexF64; +} + +class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan { + public: + port::Status init(const blas::BlasLtMatmulPlanParams &p) { + params_ = p; + scale_type_ = GetScaleType(p.c_type, p.computation_type); + SE_ASSIGN_OR_RETURN( + op_desc_, + CreateCublasLtOperationDesc( + p.computation_type, GetScaleType(p.c_type, p.computation_type), + p.pointer_mode, p.epilogue, p.transa, p.transb)); + uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k; + uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m; + uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n; + uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k; + SE_ASSIGN_OR_RETURN( + a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, + p.stride_a, capped_batch_count())); + SE_ASSIGN_OR_RETURN( + b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, + p.stride_b, capped_batch_count())); + SE_ASSIGN_OR_RETURN( + c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, + capped_batch_count())); + SE_ASSIGN_OR_RETURN( + d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, + capped_batch_count())); + remainder_batch_count_ = + p.batch_count > kMaxBatchCount ? p.batch_count % kMaxBatchCount : 0; + if (remainder_batch_count_) { + SE_ASSIGN_OR_RETURN( + a_remainder_desc_, + CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a, + remainder_batch_count_)); + SE_ASSIGN_OR_RETURN( + b_remainder_desc_, + CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b, + remainder_batch_count_)); + SE_ASSIGN_OR_RETURN( + c_remainder_desc_, + CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, + remainder_batch_count_)); + SE_ASSIGN_OR_RETURN( + d_remainder_desc_, + CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, + remainder_batch_count_)); + } + return port::Status::OK(); + } + + cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); } + cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); } + cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); } + cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); } + cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); } + cublasLtMatrixLayout_t a_remainder_desc() const { + return a_remainder_desc_.get(); + } + cublasLtMatrixLayout_t b_remainder_desc() const { + return b_remainder_desc_.get(); + } + cublasLtMatrixLayout_t c_remainder_desc() const { + return c_remainder_desc_.get(); + } + cublasLtMatrixLayout_t d_remainder_desc() const { + return d_remainder_desc_.get(); + } + + const blas::BlasLtMatmulPlanParams ¶ms() const { return params_; } + blas::DataType scale_type() const { return scale_type_; } + blas::DataType ab_type() const override { return params_.ab_type; } + blas::DataType c_type() const override { return params_.c_type; } + int capped_batch_count() const { + return std::min(params_.batch_count, kMaxBatchCount); + } + int remainder_batch_count() const { return remainder_batch_count_; } + + // Note: Must be const to satisfy API. This is always called before the plan + // is executed, so the state change is not observed in subsequent executions. + bool SetBiasPointer(const void *bias) const; + + private: + // In some cases cublasLt does not support large batch sizes, so we need to + // split up such cases into multiple calls. + static constexpr const int kMaxBatchCount = 65535; + blas::BlasLtMatmulPlanParams params_; + blas::DataType scale_type_; + UniqueOpDesc op_desc_; + // These have batch count set to capped_batch_count(). + UniqueLayoutDesc a_desc_; + UniqueLayoutDesc b_desc_; + UniqueLayoutDesc c_desc_; + UniqueLayoutDesc d_desc_; + int remainder_batch_count_; + // These have batch count set to remainder_batch_count_, and are only created + // if params_.batch_count > kMaxBatchSize. + UniqueLayoutDesc a_remainder_desc_; + UniqueLayoutDesc b_remainder_desc_; + UniqueLayoutDesc c_remainder_desc_; + UniqueLayoutDesc d_remainder_desc_; +}; + +bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const { + return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER, + bias) + .ok(); +} + +class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm { + public: + CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index, + cublasLtMatmulAlgo_t algo, size_t workspace_size) + : index_(index), algo_(algo), workspace_size_(workspace_size) {} + + blas::AlgorithmType index() const override { return index_; } + + size_t workspace_size() const override { return workspace_size_; } + + const cublasLtMatmulAlgo_t *algo() const { return &algo_; } + + int algo_id() const { + int id; + GetCublasLtAttr(&algo_, CUBLASLT_ALGO_CONFIG_ID, &id); + return id; + } + + private: + blas::AlgorithmType index_; + cublasLtMatmulAlgo_t algo_; + size_t workspace_size_; +}; + +port::StatusOr CreateCublasLtMatmulPreference( + const blas::IBlasLtMatmulPlan *plan, size_t max_workspace_bytes) { + cublasLtMatmulPreference_t preference; + cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference); + if (status != CUBLAS_STATUS_SUCCESS) { + return port::Status(port::error::INTERNAL, + absl::StrCat("cublasLtMatmulPreferenceCreate failed: ", + ToString(status))); + } + UniqueMatmulPreference unique_preference(preference); + SE_RETURN_IF_ERROR(SetCublasLtAttr(preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + max_workspace_bytes)); + + const auto &cuda_plan = *static_cast(plan); + if (cuda_plan.params().batch_count == 0) { + return unique_preference; + } + // This is a workaround for a known issue in cuBlasLt where the heuristic may + // in rare cases select an algo that does not support the specified stride. + // Specifying the alignment requirements manually like this avoids the issue. + auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) { + return (stride & -stride) * GetDataTypeSizeBytes(dtype); + }; + if (cuda_plan.params().stride_a) { + SE_RETURN_IF_ERROR(SetCublasLtAttr( + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, + (uint32)get_alignment_bytes(cuda_plan.params().stride_a, + cuda_plan.params().ab_type))); + } + if (cuda_plan.params().stride_b) { + SE_RETURN_IF_ERROR(SetCublasLtAttr( + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, + (uint32)get_alignment_bytes(cuda_plan.params().stride_b, + cuda_plan.params().ab_type))); + } + if (cuda_plan.params().stride_c) { + SE_RETURN_IF_ERROR(SetCublasLtAttr( + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, + (uint32)get_alignment_bytes(cuda_plan.params().stride_c, + cuda_plan.params().c_type))); + } + if (cuda_plan.params().stride_c) { + SE_RETURN_IF_ERROR(SetCublasLtAttr( + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, + (uint32)get_alignment_bytes(cuda_plan.params().stride_c, + cuda_plan.params().c_type))); + } + return unique_preference; +} + +} // namespace + +#endif // CUDA_VERSION >= 11000 + +port::StatusOr> +CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) { +#if CUDA_VERSION >= 11000 + auto cuda_plan = std::make_unique(); + SE_RETURN_IF_ERROR(cuda_plan->init(p)); + return static_cast>( + std::move(cuda_plan)); +#else + return port::Status( + port::error::UNIMPLEMENTED, + "CreateBlasLtMatmulPlan is not supported with this version of CUDA"); +#endif +} + +port::StatusOr>> +CUDABlas::GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan, + size_t max_workspace_size, + int max_algorithm_count, + bool for_remainder_batch) { +#if CUDA_VERSION >= 11000 + SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference, + CreateCublasLtMatmulPreference(plan, max_workspace_size)); + + std::vector results(max_algorithm_count); + { + absl::MutexLock lock(&mu_); + + CHECK(blasLt_ != nullptr); + + gpu::ScopedActivateExecutorContext sac{parent_}; + + int found_algorithm_count = 0; + const auto &cuda_plan = *static_cast(plan); + const auto &a_desc = + for_remainder_batch ? cuda_plan.a_remainder_desc() : cuda_plan.a_desc(); + const auto &b_desc = + for_remainder_batch ? cuda_plan.b_remainder_desc() : cuda_plan.b_desc(); + const auto &c_desc = + for_remainder_batch ? cuda_plan.c_remainder_desc() : cuda_plan.c_desc(); + const auto &d_desc = + for_remainder_batch ? cuda_plan.d_remainder_desc() : cuda_plan.d_desc(); + cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic( + blasLt_, cuda_plan.op_desc(), a_desc, b_desc, c_desc, d_desc, + preference.get(), max_algorithm_count, results.data(), + &found_algorithm_count); + if (status != CUBLAS_STATUS_SUCCESS) { + return port::Status( + port::error::INTERNAL, + absl::StrCat("cublasLtMatmulAlgoGetHeuristic failed: ", + ToString(status))); + } + results.resize(found_algorithm_count); + } + + std::vector> out_algorithms; + out_algorithms.reserve(results.size()); + for (size_t i = 0; i < results.size(); ++i) { + const auto &result = results[i]; + if (result.state != CUBLAS_STATUS_SUCCESS) continue; // Skip failed algos + out_algorithms.emplace_back(std::make_unique( + i, result.algo, result.workspaceSize)); + } + return out_algorithms; +#else // if CUDA_VERSION < 11000 + return port::Status( + port::error::UNIMPLEMENTED, + "GetBlasLtMatmulAlgorithms is not supported with this version of CUDA"); +#endif +} + +port::StatusOr>> +CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, + size_t max_workspace_size, + int max_algorithm_count) { + return GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size, + max_algorithm_count); +} + +#if CUDA_VERSION >= 11000 +bool CUDABlas::DoBlasLtMatmulInternal( + Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan, + const HostOrDeviceScalar &alpha, DeviceMemoryBase a, + DeviceMemoryBase b, const HostOrDeviceScalar &beta, + DeviceMemoryBase c, DeviceMemoryBase d, ScratchAllocator *scratch_allocator, + const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias) { + const auto &cuda_plan = *static_cast(plan); + const auto &cuda_algo = + *static_cast(algorithm); + + if (alpha.data_type() != cuda_plan.scale_type() || + beta.data_type() != cuda_plan.scale_type()) { + VLOG(2) << "DoBlasLtMatmul returning false because alpha and beta types do " + "not match plan: expected " + << cuda_plan.c_type() << ", got alpha=" << alpha.data_type() + << " beta=" << beta.data_type(); + return false; + } + if (alpha.is_pointer() != beta.is_pointer()) { + VLOG(2) << "DoBlasLtMatmul returning false because one of `alpha` " + "and `beta` is a pointer, but the other is not."; + return false; + } + bool is_pointer_mode_host = !alpha.is_pointer(); + if ((cuda_plan.params().pointer_mode == blas::PointerMode::kHost) != + is_pointer_mode_host) { + VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong " + "pointer_mode for the given alpha/beta."; + return false; + } + if ((cuda_plan.params().epilogue == blas::Epilogue::kBias || + cuda_plan.params().epilogue == blas::Epilogue::kBiasThenReLU) != + (bias != nullptr)) { + VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong " + "epilogue for the given bias pointer."; + return false; + } + const void *alpha_ptr = alpha.is_pointer() ? alpha.opaque_pointer().opaque() + : alpha.opaque_value(); + const void *beta_ptr = + beta.is_pointer() ? beta.opaque_pointer().opaque() : beta.opaque_value(); + + void *workspace = nullptr; + if (cuda_algo.workspace_size()) { + port::Status allocation_status = AllocateWorkspace( + &workspace, scratch_allocator, cuda_algo.workspace_size()); + if (!allocation_status.ok()) { + if (err_on_failure || VLOG_IS_ON(3)) { + LOG(ERROR) + << "Failed to allocate workspace for cublasLtMatmul algo with id: " + << cuda_algo.algo_id() << " requiring " + << cuda_algo.workspace_size() << " bytes of workspace"; + } + return false; + } + } + + // This is only used when batch_count > kMaxBatchCount. + std::unique_ptr unique_remainder_algo; + if (cuda_plan.remainder_batch_count()) { + // There is no easy way to get the user-specified max workspace size here, + // so we just allow a very small amount and don't worry too much about + // performance because this is only used in rare cases. The same reasoning + // applies to selection of the algorithm. + size_t max_workspace_size = 4 * 1024 * 1024; // 4 MiB + auto status_or_algorithms = + GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size, + /* max_algorithm_count = */ 1, + /* for_remainder_batch = */ true); + if (!status_or_algorithms.ok()) { + if (err_on_failure || VLOG_IS_ON(3)) { + LOG(ERROR) << "Failed to get algorithms for blasLt remainder batch."; + } + return false; + } + auto algorithms = status_or_algorithms.ConsumeValueOrDie(); + unique_remainder_algo = std::move(algorithms.front()); + } + + cudaStream_t cuda_stream = CUDAStream(stream); + + absl::MutexLock lock(&mu_); + + if (bias != nullptr) { + if (!cuda_plan.SetBiasPointer(bias.opaque())) { + VLOG(2) << "DoBlasLtMatmul returning false because setting the bias " + "pointer failed."; + return false; + } + } + + CHECK(blasLt_ != nullptr); + + gpu::ScopedActivateExecutorContext sac{parent_}; + + // Plan execution is broken down into repeat calls with capped_batch_count, + // followed by a final call with remainder_batch_count. + // Cases where batch_count <= kMaxBatchCount require only a single call (a + // single loop iteration and no remainder). + int ab_type_size = GetDataTypeSizeBytes(cuda_plan.params().ab_type); + int c_type_size = GetDataTypeSizeBytes(cuda_plan.params().c_type); + const char *a_ptr = static_cast(a.opaque()); + const char *b_ptr = static_cast(b.opaque()); + const char *c_ptr = static_cast(c.opaque()); + char *d_ptr = static_cast(d.opaque()); + int capped_batch_count = cuda_plan.capped_batch_count(); + for (int batch = 0; + batch + capped_batch_count <= cuda_plan.params().batch_count; + batch += capped_batch_count) { + cublasStatus_t ret = cublasLtMatmul( + blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr, cuda_plan.a_desc(), + b_ptr, cuda_plan.b_desc(), beta_ptr, c_ptr, cuda_plan.c_desc(), d_ptr, + cuda_plan.d_desc(), cuda_algo.algo(), workspace, + cuda_algo.workspace_size(), cuda_stream); + if (ret != CUBLAS_STATUS_SUCCESS) { + if (err_on_failure || VLOG_IS_ON(3)) { + LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret); + } + return false; + } + a_ptr += capped_batch_count * cuda_plan.params().stride_a * ab_type_size; + b_ptr += capped_batch_count * cuda_plan.params().stride_b * ab_type_size; + c_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size; + d_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size; + } + // This is only used when batch_count > kMaxBatchCount. + if (cuda_plan.remainder_batch_count()) { + const auto &remainder_algo = + *static_cast( + unique_remainder_algo.get()); + if (remainder_algo.workspace_size()) { + port::Status allocation_status = AllocateWorkspace( + &workspace, scratch_allocator, remainder_algo.workspace_size()); + if (!allocation_status.ok()) { + if (err_on_failure || VLOG_IS_ON(3)) { + LOG(ERROR) << "Failed to allocate workspace for cublasLtMatmul algo " + "with id: " + << remainder_algo.algo_id() << " requiring " + << remainder_algo.workspace_size() + << " bytes of workspace"; + } + return false; + } + } + cublasStatus_t ret = cublasLtMatmul( + blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr, + cuda_plan.a_remainder_desc(), b_ptr, cuda_plan.b_remainder_desc(), + beta_ptr, c_ptr, cuda_plan.c_remainder_desc(), d_ptr, + cuda_plan.d_remainder_desc(), remainder_algo.algo(), workspace, + remainder_algo.workspace_size(), cuda_stream); + if (ret != CUBLAS_STATUS_SUCCESS) { + if (err_on_failure || VLOG_IS_ON(3)) { + LOG(ERROR) << "failed to run remainder cublasLtMatmul routine: " + << ToString(ret); + } + return false; + } + } + return true; +} +#endif // CUDA_VERSION >= 11000 + +bool CUDABlas::DoBlasLtMatmul( + Stream *stream, const blas::IBlasLtMatmulPlan *plan, + const HostOrDeviceScalar &alpha, DeviceMemoryBase a, + DeviceMemoryBase b, const HostOrDeviceScalar &beta, + DeviceMemoryBase c, ScratchAllocator *scratch_allocator, + const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, + blas::ProfileResult *output_profile_result) { +#if CUDA_VERSION >= 11000 + const auto &cuda_plan = *static_cast(plan); + HostOrDeviceScalar alpha_cast = alpha; + HostOrDeviceScalar beta_cast = beta; + if (cuda_plan.c_type() == blas::DataType::kHalf && + cuda_plan.scale_type() == blas::DataType::kFloat) { + // The given alpha and beta types are F16 (they always match c), but F32* + // computation type requires that they be F32, so we must cast them. + if (alpha.is_pointer() || beta.is_pointer()) { + // We cannot easily convert a pointer to f16 memory to a pointer to f32 + // memory from here, so we don't support this for now. + return false; + } + alpha_cast = HostOrDeviceScalar( + static_cast(alpha.value())); + beta_cast = + HostOrDeviceScalar(static_cast(beta.value())); + } + + std::unique_ptr timer; + if (output_profile_result) { + timer.reset(new GpuTimer(parent_)); + if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { + return false; + } + } + + bool err_on_failure = timer != nullptr; + bool result = DoBlasLtMatmulInternal(stream, err_on_failure, plan, alpha_cast, + a, b, beta_cast, c, c, scratch_allocator, + algorithm, bias); + + if (timer && result) { + // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error + // state. + if (!timer->Stop(AsGpuStream(stream))) { + return false; + } + output_profile_result->set_is_valid(true); + output_profile_result->set_algorithm(algorithm->index()); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); + } + return result; +#else // if CUDA_VERSION < 11000 + return false; +#endif +} + port::Status CUDABlas::GetVersion(std::string *version) { absl::MutexLock lock(&mu_); diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h index 9ff63102aaa..ca2aa15d938 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.h +++ b/tensorflow/stream_executor/cuda/cuda_blas.h @@ -21,7 +21,9 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ #include "absl/synchronization/mutex.h" +#include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/host_or_device_scalar.h" @@ -71,6 +73,9 @@ class CUDABlas : public blas::BlasSupport { // invoked before calling into cuBLAS. bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Returns the underlying CUDA stream. + cudaStream_t CUDAStream(Stream *stream); + // A helper function that calls the real cuBLAS function together with error // handling. // @@ -134,6 +139,24 @@ class CUDABlas : public blas::BlasSupport { const T &beta, DeviceMemory *y, int incy, blas::ProfileResult *output_profile_result); + // Helper function for implementing DoBlasLtMatmul. + bool DoBlasLtMatmulInternal(Stream *stream, bool err_on_failure, + const blas::IBlasLtMatmulPlan *plan, + const HostOrDeviceScalar &alpha, + DeviceMemoryBase a, DeviceMemoryBase b, + const HostOrDeviceScalar &beta, + DeviceMemoryBase c, DeviceMemoryBase d, + ScratchAllocator *scratch_allocator, + const blas::IBlasLtMatmulAlgorithm *algorithm, + DeviceMemoryBase bias); + + // Helper function for implementing GetBlasLtMatmulAlgorithms. + port::StatusOr>> + GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan, + size_t max_workspace_size, + int max_algorithm_count, + bool for_remainder_batch = false); + // Guards the cuBLAS handle for this device. absl::Mutex mu_; @@ -144,6 +167,11 @@ class CUDABlas : public blas::BlasSupport { // cuBLAS library handle on the device. cublasHandle_t blas_ TF_GUARDED_BY(mu_); +#if CUDA_VERSION >= 11000 + // cuBLASLt library handle on the device. + cublasLtHandle_t blasLt_ GUARDED_BY(mu_); +#endif + SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas); }; diff --git a/tensorflow/stream_executor/data_type.h b/tensorflow/stream_executor/data_type.h new file mode 100644 index 00000000000..a1ce663c51e --- /dev/null +++ b/tensorflow/stream_executor/data_type.h @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_DATA_TYPE_H_ +#define TENSORFLOW_STREAM_EXECUTOR_DATA_TYPE_H_ + +#include + +#include "tensorflow/stream_executor/dnn.pb.h" +#include "tensorflow/stream_executor/platform/port.h" + +namespace Eigen { +struct half; +} // namespace Eigen + +namespace stream_executor { +namespace dnn { + +// A helper class to convert C/C++ types to the proper enums. +template +struct ToDataType; +template <> +struct ToDataType { + static constexpr DataType value = DataType::kFloat; +}; +template <> +struct ToDataType { + static constexpr DataType value = DataType::kDouble; +}; +template <> +struct ToDataType { + static constexpr DataType value = DataType::kHalf; +}; +template <> +struct ToDataType { + static constexpr DataType value = DataType::kInt8; +}; +template <> +struct ToDataType { + static constexpr DataType value = DataType::kInt32; +}; +template <> +struct ToDataType> { + static constexpr DataType value = DataType::kComplexFloat; +}; +template <> +struct ToDataType> { + static constexpr DataType value = DataType::kComplexDouble; +}; + +} // namespace dnn +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_DATA_TYPE_H_ diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 53cdff8cb7a..920f5fe246c 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -30,6 +30,7 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/span.h" +#include "tensorflow/stream_executor/data_type.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/dnn.pb.h" #include "tensorflow/stream_executor/lib/array_slice.h" @@ -110,30 +111,6 @@ enum class QuantizedActivationMode { k32Bit = 4, }; -// A helper class to convert C/C++ types to the proper enums. -template -struct ToDataType; -template <> -struct ToDataType { - static constexpr DataType value = DataType::kFloat; -}; -template <> -struct ToDataType { - static constexpr DataType value = DataType::kDouble; -}; -template <> -struct ToDataType { - static constexpr DataType value = DataType::kHalf; -}; -template <> -struct ToDataType { - static constexpr DataType value = DataType::kInt8; -}; -template <> -struct ToDataType { - static constexpr DataType value = DataType::kInt32; -}; - // Specifies the types of a RNN model. enum class RnnMode { kRnnRelu = 0, diff --git a/tensorflow/stream_executor/dnn.proto b/tensorflow/stream_executor/dnn.proto index 4d09e615e7d..f849b011eb3 100644 --- a/tensorflow/stream_executor/dnn.proto +++ b/tensorflow/stream_executor/dnn.proto @@ -12,6 +12,8 @@ enum DataType { kHalf = 2; kInt8 = 3; kInt32 = 4; + kComplexFloat = 5; + kComplexDouble = 6; } // Describes how a convolution input or output layer's data is formatted. diff --git a/tensorflow/stream_executor/gpu/asm_compiler.cc b/tensorflow/stream_executor/gpu/asm_compiler.cc index 53f76503f2a..6127f644471 100644 --- a/tensorflow/stream_executor/gpu/asm_compiler.cc +++ b/tensorflow/stream_executor/gpu/asm_compiler.cc @@ -225,6 +225,21 @@ port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, int exit_status = ptxas_info_dumper.Communicate( /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); if (exit_status != 0) { + // It happens when the ptxas installed is too old for the current GPU. + // Example error message associated with this error code: + // ptxas fatal : Value 'sm_80' is not defined for option 'gpu-name' + // In that case, fallback to the driver for compilation + if (absl::StartsWith(stderr_output, "ptxas fatal : Value '") && + absl::StrContains(stderr_output, + "is not defined for option 'gpu-name'")) { + LOG(WARNING) << "Your CUDA software stack is old. We fallback to the" + << " NVIDIA driver for some compilation. Update your CUDA" + << " version to get the best performance." + << " The ptxas error was: " << stderr_output; + return tensorflow::errors::Unimplemented( + ptxas_path, " ptxas too old. Falling back to the driver to compile."); + } + return port::InternalError( absl::StrFormat("ptxas exited with non-zero error code %d, output: %s", exit_status, stderr_output)); diff --git a/tensorflow/stream_executor/host_or_device_scalar.h b/tensorflow/stream_executor/host_or_device_scalar.h index 1f5d4b9260c..3274e7849fe 100644 --- a/tensorflow/stream_executor/host_or_device_scalar.h +++ b/tensorflow/stream_executor/host_or_device_scalar.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_ #define TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_ +#include "tensorflow/stream_executor/data_type.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/platform/logging.h" @@ -23,6 +24,7 @@ namespace stream_executor { // Allows to represent a value that is either a host scalar or a scalar stored // on the GPU device. +// See also the specialization for ElemT=void below. template class HostOrDeviceScalar { public: @@ -52,5 +54,154 @@ class HostOrDeviceScalar { bool is_pointer_; }; +// Specialization for wrapping a dynamically-typed value (via type erasure). +template <> +class HostOrDeviceScalar { + public: + using DataType = dnn::DataType; + + // Constructors not marked as explicit because when using this constructor, we + // usually want to set this to a compile-time constant. + + // NOLINTNEXTLINE google-explicit-constructor + HostOrDeviceScalar(float value) + : float_(value), is_pointer_(false), dtype_(DataType::kFloat) {} + // NOLINTNEXTLINE google-explicit-constructor + HostOrDeviceScalar(double value) + : double_(value), is_pointer_(false), dtype_(DataType::kDouble) {} + // NOLINTNEXTLINE google-explicit-constructor + HostOrDeviceScalar(Eigen::half value) + : half_(value), is_pointer_(false), dtype_(DataType::kHalf) {} + // NOLINTNEXTLINE google-explicit-constructor + HostOrDeviceScalar(int8 value) + : int8_(value), is_pointer_(false), dtype_(DataType::kInt8) {} + // NOLINTNEXTLINE google-explicit-constructor + HostOrDeviceScalar(int32 value) + : int32_(value), is_pointer_(false), dtype_(DataType::kInt32) {} + // NOLINTNEXTLINE google-explicit-constructor + HostOrDeviceScalar(std::complex value) + : complex_float_(value), + is_pointer_(false), + dtype_(DataType::kComplexFloat) {} + // NOLINTNEXTLINE google-explicit-constructor + HostOrDeviceScalar(std::complex value) + : complex_double_(value), + is_pointer_(false), + dtype_(DataType::kComplexDouble) {} + template + explicit HostOrDeviceScalar(const DeviceMemory& pointer) + : pointer_(pointer), + is_pointer_(true), + dtype_(dnn::ToDataType::value) { + CHECK_EQ(1, pointer.ElementCount()); + } + // Construct from statically-typed version. + template ::value, + int>::type = 0> + // NOLINTNEXTLINE google-explicit-constructor + HostOrDeviceScalar(const HostOrDeviceScalar& other) { + if (other.is_pointer()) { + *this = HostOrDeviceScalar(other.pointer()); + } else { + *this = HostOrDeviceScalar(other.value()); + } + } + + bool is_pointer() const { return is_pointer_; } + template + const DeviceMemory& pointer() const { + CHECK(is_pointer()); + CHECK(dtype_ == dnn::ToDataType::value); + return pointer_; + } + template + const T& value() const { + CHECK(!is_pointer()); + CHECK(dtype_ == dnn::ToDataType::value); + return value_impl(); + } + const DeviceMemoryBase& opaque_pointer() const { + CHECK(is_pointer()); + return pointer_; + } + const void* opaque_value() const { + CHECK(!is_pointer()); + switch (dtype_) { + case DataType::kFloat: + return &float_; + case DataType::kDouble: + return &double_; + case DataType::kHalf: + return &half_; + case DataType::kInt8: + return &int8_; + case DataType::kInt32: + return &int32_; + case DataType::kComplexFloat: + return &complex_float_; + case DataType::kComplexDouble: + return &complex_double_; + default: + return nullptr; + } + } + DataType data_type() const { return dtype_; } + + private: + template + const T& value_impl() const; + + union { + float float_; + double double_; + Eigen::half half_; + int8 int8_; + int32 int32_; + std::complex complex_float_; + std::complex complex_double_; + DeviceMemoryBase pointer_; + }; + bool is_pointer_; + DataType dtype_; +}; + +template <> +inline const float& HostOrDeviceScalar::value_impl() const { + return float_; +} + +template <> +inline const double& HostOrDeviceScalar::value_impl() const { + return double_; +} + +template <> +inline const Eigen::half& HostOrDeviceScalar::value_impl() + const { + return half_; +} + +template <> +inline const int8& HostOrDeviceScalar::value_impl() const { + return int8_; +} + +template <> +inline const int32& HostOrDeviceScalar::value_impl() const { + return int32_; +} + +template <> +inline const std::complex& +HostOrDeviceScalar::value_impl>() const { + return complex_float_; +} + +template <> +inline const std::complex& +HostOrDeviceScalar::value_impl>() const { + return complex_double_; +} + } // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_ diff --git a/tensorflow/stream_executor/platform/default/dlopen_checker.cc b/tensorflow/stream_executor/platform/default/dlopen_checker.cc index b55c9f53793..7b38dfcfec0 100644 --- a/tensorflow/stream_executor/platform/default/dlopen_checker.cc +++ b/tensorflow/stream_executor/platform/default/dlopen_checker.cc @@ -23,6 +23,7 @@ namespace DsoLoader { port::Status TryDlopenCUDALibraries() { auto cudart_status = GetCudaRuntimeDsoHandle(); auto cublas_status = GetCublasDsoHandle(); + auto cublaslt_status = GetCublasLtDsoHandle(); auto cufft_status = GetCufftDsoHandle(); auto curand_status = GetCurandDsoHandle(); auto cusolver_status = GetCusolverDsoHandle(); @@ -31,7 +32,7 @@ port::Status TryDlopenCUDALibraries() { if (!cudart_status.status().ok() || !cublas_status.status().ok() || !cufft_status.status().ok() || !curand_status.status().ok() || !cusolver_status.status().ok() || !cusparse_status.status().ok() || - !cudnn_status.status().ok()) { + !cudnn_status.status().ok() || !cublaslt_status.status().ok()) { return port::Status(port::error::INTERNAL, absl::StrCat("Cannot dlopen all CUDA libraries.")); } else { diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc index a78c738f32c..8b8cb2ff937 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.cc +++ b/tensorflow/stream_executor/platform/default/dso_loader.cc @@ -85,6 +85,10 @@ port::StatusOr GetCublasDsoHandle() { return GetDsoHandle("cublas", GetCublasVersion()); } +port::StatusOr GetCublasLtDsoHandle() { + return GetDsoHandle("cublasLt", GetCublasVersion()); +} + port::StatusOr GetCufftDsoHandle() { return GetDsoHandle("cufft", GetCufftVersion()); } @@ -161,6 +165,11 @@ port::StatusOr GetCublasDsoHandle() { return *result; } +port::StatusOr GetCublasLtDsoHandle() { + static auto result = new auto(DsoLoader::GetCublasLtDsoHandle()); + return *result; +} + port::StatusOr GetCurandDsoHandle() { static auto result = new auto(DsoLoader::GetCurandDsoHandle()); return *result; diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h index 91138f713bd..7f087349fcf 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.h +++ b/tensorflow/stream_executor/platform/default/dso_loader.h @@ -37,6 +37,7 @@ namespace DsoLoader { port::StatusOr GetCudaDriverDsoHandle(); port::StatusOr GetCudaRuntimeDsoHandle(); port::StatusOr GetCublasDsoHandle(); +port::StatusOr GetCublasLtDsoHandle(); port::StatusOr GetCufftDsoHandle(); port::StatusOr GetCurandDsoHandle(); port::StatusOr GetCusolverDsoHandle(); @@ -72,6 +73,7 @@ namespace CachedDsoLoader { port::StatusOr GetCudaDriverDsoHandle(); port::StatusOr GetCudaRuntimeDsoHandle(); port::StatusOr GetCublasDsoHandle(); +port::StatusOr GetCublasLtDsoHandle(); port::StatusOr GetCufftDsoHandle(); port::StatusOr GetCurandDsoHandle(); port::StatusOr GetCusolverDsoHandle(); diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc index 5ddad13ddf9..2223cb9ad67 100644 --- a/tensorflow/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/stream_executor/rocm/rocm_blas.cc @@ -2540,6 +2540,32 @@ port::Status ROCMBlas::GetVersion(string *version) { return port::UnimplementedError(""); } +port::StatusOr> +ROCMBlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) { + return port::Status( + port::error::UNIMPLEMENTED, + "CreateBlasLtMatmulPlan is not supported with this version of ROCM"); +} + +port::StatusOr>> +ROCMBlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, + size_t max_workspace_size, + int max_algorithm_count) { + return port::Status( + port::error::UNIMPLEMENTED, + "GetBlasLtMatmulAlgorithms is not supported with this version of ROCM"); +} + +bool ROCMBlas::DoBlasLtMatmul( + Stream *stream, const blas::IBlasLtMatmulPlan *plan, + const HostOrDeviceScalar &alpha, DeviceMemoryBase a, + DeviceMemoryBase b, const HostOrDeviceScalar &beta, + DeviceMemoryBase c, ScratchAllocator *scratch_allocator, + const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, + blas::ProfileResult *output_profile_result) { + return false; +} + } // namespace gpu void initialize_rocblas() { diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 46441d396b7..4ad9fc128cc 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -4322,6 +4322,80 @@ Stream &Stream::ThenBlasGemmStridedBatched( c, ldc, stride_c, batch_count); } +template +Stream &Stream::ThenBlasLtMatmulImpl( + const blas::IBlasLtMatmulPlan *plan, const HostOrDeviceScalar &alpha, + const DeviceMemory &a, const DeviceMemory &b, + const HostOrDeviceScalar &beta, DeviceMemory *c, + ScratchAllocator *scratch_allocator, + const blas::IBlasLtMatmulAlgorithm *algorithm, + const DeviceMemory &bias, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta), + PARAM(c), PARAM(algorithm), PARAM(bias)); + + ThenBlasWithProfileImpl< + const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar &, + const DeviceMemory &, const DeviceMemory &, + const HostOrDeviceScalar &, DeviceMemory *, + ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *, + const DeviceMemory &> + impl; + return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta, + c, scratch_allocator, algorithm, bias, output_profile_result); +} + +// Explicit template instantiations for each supported type combination. +template Stream &Stream::ThenBlasLtMatmulImpl( + const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar &, + const DeviceMemory &, const DeviceMemory &, + const HostOrDeviceScalar &, DeviceMemory *, + ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *, + const DeviceMemory &, blas::ProfileResult *); + +template Stream &Stream::ThenBlasLtMatmulImpl( + const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar &, + const DeviceMemory &, const DeviceMemory &, + const HostOrDeviceScalar &, DeviceMemory *, + ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *, + const DeviceMemory &, blas::ProfileResult *); + +template Stream &Stream::ThenBlasLtMatmulImpl( + const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar &, + const DeviceMemory &, const DeviceMemory &, + const HostOrDeviceScalar &, DeviceMemory *, + ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *, + const DeviceMemory &, blas::ProfileResult *); + +template Stream &Stream::ThenBlasLtMatmulImpl( + const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar &, + const DeviceMemory &, const DeviceMemory &, + const HostOrDeviceScalar &, DeviceMemory *, + ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *, + const DeviceMemory &, blas::ProfileResult *); + +template Stream & +Stream::ThenBlasLtMatmulImpl, std::complex>( + const blas::IBlasLtMatmulPlan *, + const HostOrDeviceScalar> &, + const DeviceMemory> &, + const DeviceMemory> &, + const HostOrDeviceScalar> &, + DeviceMemory> *, ScratchAllocator *, + const blas::IBlasLtMatmulAlgorithm *, + const DeviceMemory> &, blas::ProfileResult *); + +template Stream & +Stream::ThenBlasLtMatmulImpl, std::complex>( + const blas::IBlasLtMatmulPlan *, + const HostOrDeviceScalar> &, + const DeviceMemory> &, + const DeviceMemory> &, + const HostOrDeviceScalar> &, + DeviceMemory> *, ScratchAllocator *, + const blas::IBlasLtMatmulAlgorithm *, + const DeviceMemory> &, blas::ProfileResult *); + Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { VLOG_CALL(PARAM(seed), PARAM(seed_bytes)); diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index df5acddccd5..b1460b02935 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -75,6 +75,19 @@ class AlgorithmDesc; class StreamExecutor; class ScratchAllocator; +namespace detail { + +// Helper class to prevent a template function argument from being deduced. This +// is identical to std::type_identity in C++20. +template +struct NonDeduced { + using type = T; +}; +template +using NonDeducedType = typename NonDeduced::type; + +} // namespace detail + // Convert a type to the corresponding QuantizedActivationMode. template struct Quantization; @@ -1632,6 +1645,25 @@ class Stream { const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb); + // See BlasSupport::DoBlatLtMatmul. + // Note that we prevent alpha and beta from being used to deduce CType so that + // they can be constructed implicitly from values of type CType. Without this, + // type deduction would fail when this function is called with a value of type + // CType for alpha or beta. + template + Stream &ThenBlasLtMatmul( + const blas::IBlasLtMatmulPlan *plan, + const detail::NonDeducedType> &alpha, + const DeviceMemory &a, const DeviceMemory &b, + const detail::NonDeducedType> &beta, + DeviceMemory *c, ScratchAllocator *scratch_allocator, + const blas::IBlasLtMatmulAlgorithm *algorithm, + const DeviceMemory &bias = {}, + blas::ProfileResult *output_profile_result = nullptr) { + return ThenBlasLtMatmulImpl(plan, alpha, a, b, beta, c, scratch_allocator, + algorithm, bias, output_profile_result); + } + // See FftSupport::DoFft. Stream &ThenFft(fft::Plan *plan, const DeviceMemory> &input, @@ -2064,6 +2096,19 @@ class Stream { const dnn::BatchDescriptor &bias_descriptor, DeviceMemory *backward_bias_data); + // Implementation of ThenBlasLtMatmul that is shared by all types. + template + Stream &ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan *plan, + const HostOrDeviceScalar &alpha, + const DeviceMemory &a, + const DeviceMemory &b, + const HostOrDeviceScalar &beta, + DeviceMemory *c, + ScratchAllocator *scratch_allocator, + const blas::IBlasLtMatmulAlgorithm *algorithm, + const DeviceMemory &bias, + blas::ProfileResult *output_profile_result); + SE_DISALLOW_COPY_AND_ASSIGN(Stream); }; diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index c32793e3e83..3c6f70ae2b0 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -337,6 +337,30 @@ bool StreamExecutor::GetBlasGemmAlgorithms( return blas_support->GetBlasGemmAlgorithms(out_algorithms); } +port::StatusOr> +StreamExecutor::CreateBlasLtMatmulPlan( + const blas::BlasLtMatmulPlanParams ¶ms) { + blas::BlasSupport *blas_support = AsBlas(); + if (!blas_support) { + return port::Status(port::error::UNKNOWN, + "Fail to find the blas implementation."); + } + return blas_support->CreateBlasLtMatmulPlan(params); +} + +port::StatusOr>> +StreamExecutor::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, + size_t max_workspace_size, + int max_algorithm_count) { + blas::BlasSupport *blas_support = AsBlas(); + if (!blas_support) { + return port::Status(port::error::UNKNOWN, + "Fail to find the blas implementation."); + } + return blas_support->GetBlasLtMatmulAlgorithms(plan, max_workspace_size, + max_algorithm_count); +} + port::StatusOr> StreamExecutor::createRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index f19c76c3790..2ee583477b9 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -395,6 +395,21 @@ class StreamExecutor { // Get the list of supported algorithms for BLAS gemm. bool GetBlasGemmAlgorithms(std::vector *out_algorithms); + // Creates a backend-specific plan object for a blaslt matmul operation, which + // can then be passed to DoBlasLtMatmul(). When possible, plans should be + // created once and reused for multiple calls to DoBlasLtMatmul(). + // Returns a null pointer on failure. + port::StatusOr> + CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams ¶ms); + + // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are + // returned in the order of increasing estimated compute time according to an + // internal heuristic. The first returned algorithm can be used as the default + // algorithm if no autotuning is to be performed. + port::StatusOr>> + GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, + size_t max_workspace_size, int max_algorithm_count); + // Create an RNN descriptor based on model shapes and configurations. // The caller retains the ownership of the descriptor. port::StatusOr> createRnnDescriptor( diff --git a/tensorflow/stream_executor/tpu/c_api_conversions.cc b/tensorflow/stream_executor/tpu/c_api_conversions.cc index 0a7801f45fc..ba0c4c1c2a3 100644 --- a/tensorflow/stream_executor/tpu/c_api_conversions.cc +++ b/tensorflow/stream_executor/tpu/c_api_conversions.cc @@ -149,18 +149,171 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base) { return base; } -xla::Shape FromC(const XLA_Shape* shape) { - xla::ShapeProto p; - p.ParseFromArray(shape->bytes, shape->size); - return xla::Shape(p); +// Helper functions for copying data to possibly-inlined C arrays. + +// 'Src' and 'Dst' are allowed to be different types to make this usable with +// memory-identical types, e.g. int64 and int64_t. This should not be used with +// types that require a static_cast. +template +static void CopyVectorBase(const absl::Span src, DstList* dst) { + static_assert(sizeof(Src) == sizeof(Dst), "Mismatched types"); + dst->size = src.size(); + if (dst->size > TPU_C_API_MAX_INLINED) { + dst->heap = new Dst[dst->size]; + memcpy(dst->heap, src.data(), dst->size * sizeof(Src)); + } else { + memcpy(dst->inlined, src.data(), dst->size * sizeof(Src)); + } +} + +static void CopyVector(const absl::Span src, + Int64List* dst) { + return CopyVectorBase( + src, dst); +} +static void CopyVector(const absl::Span src, BoolList* dst) { + return CopyVectorBase(src, dst); +} + +static void CopyVector(const absl::Span src, TileList* dst) { + dst->size = src.size(); + XLA_Tile* c_tiles; + if (dst->size > TPU_C_API_MAX_INLINED) { + dst->heap = new XLA_Tile[dst->size]; + c_tiles = dst->heap; + } else { + c_tiles = dst->inlined; + } + for (int i = 0; i < dst->size; ++i) { + ToC(src[i], &c_tiles[i]); + } +} + +// Helper functions for creating a view of possibly-inlined C arrays. + +// 'Src' and 'Dst' are allowed to be different types to make this usable with +// memory-identical types, e.g. int64 and int64_t. This should not be used with +// types that require a static_cast. +template +static absl::Span MakeSpanBase(const SrcList& src_list) { + static_assert(sizeof(Src) == sizeof(Dst), "Mismatched types"); + const Src* src = src_list.size > TPU_C_API_MAX_INLINED ? src_list.heap + : &src_list.inlined[0]; + return absl::Span(reinterpret_cast(src), + src_list.size); +} + +static absl::Span MakeSpan( + const Int64List& src_list) { + return MakeSpanBase(src_list); +} +static absl::Span MakeSpan(const BoolList& src_list) { + return MakeSpanBase(src_list); } void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape) { - xla::ShapeProto p = xla_shape.ToProto(); - std::string p_str = p.SerializeAsString(); - c_shape->bytes = new char[p_str.size()]; - c_shape->size = p_str.size(); - memcpy(c_shape->bytes, p_str.data(), p_str.size()); + c_shape->element_type = xla_shape.element_type(); + + CopyVector(xla_shape.dimensions(), &c_shape->dimensions); + CopyVector(xla_shape.dynamic_dimensions(), &c_shape->dynamic_dimensions); + + c_shape->ntuple_shapes = xla_shape.tuple_shapes_size(); + if (c_shape->ntuple_shapes > 0) { + c_shape->tuple_shapes = new XLA_Shape[c_shape->ntuple_shapes]; + for (int i = 0; i < c_shape->ntuple_shapes; ++i) { + ToC(xla_shape.tuple_shapes(i), &c_shape->tuple_shapes[i]); + } + } + + if (xla_shape.has_layout()) { + ToC(xla_shape.layout(), &c_shape->layout); + } else { + c_shape->layout.format = xla::INVALID_FORMAT; + } +} + +xla::Shape FromC(const XLA_Shape* c_shape) { + absl::Span dims = + MakeSpan(c_shape->dimensions); + absl::Span dynamic_dims = MakeSpan(c_shape->dynamic_dimensions); + + std::vector tuple_shapes; + tuple_shapes.reserve(c_shape->ntuple_shapes); + for (int i = 0; i < c_shape->ntuple_shapes; ++i) { + tuple_shapes.push_back(FromC(&c_shape->tuple_shapes[i])); + } + + xla::Shape result(static_cast(c_shape->element_type), + dims, dynamic_dims, std::move(tuple_shapes)); + if (c_shape->layout.format != xla::INVALID_FORMAT) { + *result.mutable_layout() = FromC(&c_shape->layout); + } + return result; +} + +void Free(XLA_Shape* c_shape) { + if (c_shape->dimensions.size > TPU_C_API_MAX_INLINED) { + delete[] c_shape->dimensions.heap; + } + if (c_shape->dynamic_dimensions.size > TPU_C_API_MAX_INLINED) { + delete[] c_shape->dynamic_dimensions.heap; + } + if (c_shape->ntuple_shapes > 0) { + for (int i = 0; i < c_shape->ntuple_shapes; ++i) { + Free(&c_shape->tuple_shapes[i]); + } + delete[] c_shape->tuple_shapes; + } + if (c_shape->layout.format != xla::INVALID_FORMAT) { + Free(&c_shape->layout); + } +} + +void ToC(const xla::Layout& layout, XLA_Layout* c_layout) { + c_layout->format = layout.format(); + CopyVector(layout.minor_to_major(), &c_layout->minor_to_major); + c_layout->element_size_in_bits = layout.element_size_in_bits(); + c_layout->memory_space = layout.memory_space(); + CopyVector(layout.tiles(), &c_layout->tiles); +} + +xla::Layout FromC(const XLA_Layout* c_layout) { + absl::Span minor_to_major = + MakeSpan(c_layout->minor_to_major); + absl::InlinedVector tiles; + const XLA_Tile* c_tiles = c_layout->tiles.size > TPU_C_API_MAX_INLINED + ? c_layout->tiles.heap + : c_layout->tiles.inlined; + for (int i = 0; i < c_layout->tiles.size; ++i) { + tiles.push_back(FromC(&c_tiles[i])); + } + return xla::Layout(minor_to_major, tiles, c_layout->element_size_in_bits, + c_layout->memory_space); +} + +void Free(XLA_Layout* c_layout) { + if (c_layout->minor_to_major.size > TPU_C_API_MAX_INLINED) { + delete[] c_layout->minor_to_major.heap; + } + if (c_layout->tiles.size > TPU_C_API_MAX_INLINED) { + delete[] c_layout->tiles.heap; + } +} + +void ToC(const xla::Tile& tile, XLA_Tile* c_tile) { + CopyVector(tile.dimensions(), &c_tile->dimensions); +} + +xla::Tile FromC(const XLA_Tile* c_tile) { + absl::Span dims = + MakeSpan(c_tile->dimensions); + return xla::Tile(dims); +} + +void Free(XLA_Tile* c_tile) { + if (c_tile->dimensions.size > TPU_C_API_MAX_INLINED) { + delete[] c_tile->dimensions.heap; + } } XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape) { @@ -212,7 +365,6 @@ void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) { } } -void Free(XLA_Shape* shape) { delete[] shape->bytes; } void Free(XLA_ShapeIndex* shape_index) { delete[] shape_index; } void Free(SE_DeviceMemoryBase*) {} diff --git a/tensorflow/stream_executor/tpu/c_api_conversions.h b/tensorflow/stream_executor/tpu/c_api_conversions.h index c4b5648e097..da856a8720b 100644 --- a/tensorflow/stream_executor/tpu/c_api_conversions.h +++ b/tensorflow/stream_executor/tpu/c_api_conversions.h @@ -43,9 +43,19 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base); void Free(SE_DeviceMemoryBase*); // xla::Shape -xla::Shape FromC(const XLA_Shape* shape); +xla::Shape FromC(const XLA_Shape* c_shape); void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape); -void Free(XLA_Shape* shape); +void Free(XLA_Shape* c_shape); + +// xla::Layout +xla::Layout FromC(const XLA_Layout* c_layout); +void ToC(const xla::Layout& xla_layout, XLA_Layout* c_layout); +void Free(XLA_Layout* c_layout); + +// xla::Tile +xla::Tile FromC(const XLA_Tile* c_tile); +void ToC(const xla::Tile& xla_tile, XLA_Tile* c_tile); +void Free(XLA_Tile* c_tile); // xla::ShapeIndex XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape); diff --git a/tensorflow/stream_executor/tpu/c_api_decl.h b/tensorflow/stream_executor/tpu/c_api_decl.h index dcb53823e0c..1b92913263e 100644 --- a/tensorflow/stream_executor/tpu/c_api_decl.h +++ b/tensorflow/stream_executor/tpu/c_api_decl.h @@ -25,6 +25,9 @@ limitations under the License. extern "C" { +// Maximum number of array elements to inline into structs for performance. +#define TPU_C_API_MAX_INLINED 6 + enum TpuCoreTypeEnum { kTensorCore, kEmbeddingV1, @@ -168,11 +171,50 @@ typedef struct SE_MaybeOwningDeviceMemory { SE_DeviceMemoryAllocator allocator; } SE_MaybeOwningDeviceMemory; +struct Int64List { + union { + int64_t* heap; // owned + int64_t inlined[TPU_C_API_MAX_INLINED]; + }; + int64_t size; +}; + +struct BoolList { + union { + bool* heap; // owned + bool inlined[TPU_C_API_MAX_INLINED]; + }; + int64_t size; +}; + +typedef struct XLA_Tile { + Int64List dimensions; +} XLA_Tile; + +struct TileList { + union { + XLA_Tile* heap; // owned + XLA_Tile inlined[TPU_C_API_MAX_INLINED]; + }; + int64_t size; +}; + +typedef struct XLA_Layout { + int format; + Int64List minor_to_major; + TileList tiles; + int64_t element_size_in_bits; + int64_t memory_space; +} XLA_Layout; + // Represents an XLA shape tree. -// Shapes are flattened in default traversal order. typedef struct XLA_Shape { - char* bytes; - size_t size; + int element_type; + Int64List dimensions; + BoolList dynamic_dimensions; + XLA_Shape* tuple_shapes; // owned + int ntuple_shapes; + XLA_Layout layout; } XLA_Shape; // Represents a leaf node for a XLA shaped buffer. diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 075730742c6..4643af0e76f 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -38,7 +38,6 @@ load( "//third_party/mkl:build_defs.bzl", "if_enable_mkl", "if_mkl", - "if_mkl_lnx_x64", "if_mkl_ml", "mkl_deps", ) @@ -51,6 +50,7 @@ load( "//third_party/ngraph:build_defs.bzl", "if_ngraph", ) +load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") # version for the shared libraries, can # not contain rc or alpha, only numbers. @@ -115,11 +115,6 @@ def tf_android_core_proto_headers(core_proto_sources_relative): for p in core_proto_sources_relative ]) -# Wrapper for portable protos which currently just creates an empty rule. -def tf_portable_proto_library(name, proto_deps, deps = [], **kwargs): - _ignore = [kwargs] - cc_library(name = name, deps = deps + [dep + "_cc" for dep in proto_deps]) - def tf_portable_full_lite_protos(full, lite): return select({ "//tensorflow:mobile_lite_protos": lite, @@ -267,6 +262,12 @@ def if_libtpu(if_true, if_false = []): "//conditions:default": if_false, }) +def if_registration_v2(if_true, if_false = []): + return select({ + "//tensorflow:registration_v2": if_true, + "//conditions:default": if_false, + }) + # Linux systems may required -lrt linker flag for e.g. clock_gettime # see https://github.com/tensorflow/tensorflow/issues/15129 def lrt_if_needed(): @@ -330,8 +331,7 @@ def tf_copts( if_libtpu(["-DLIBTPU_ON_GCE"], []) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"]) + - if_mkl(["-DINTEL_MKL=1", "-DENABLE_MKLDNN_V1", "-DENABLE_INTEL_MKL_BFLOAT16"]) + - if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + + if_mkl(["-DINTEL_MKL=1", "-DENABLE_MKLDNN_V1", "-DENABLE_INTEL_MKL_BFLOAT16", "-DINTEL_MKL_DNN_ONLY"]) + if_mkldnn_threadpool(["-DENABLE_MKLDNN_THREADPOOL"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_ngraph(["-DINTEL_NGRAPH=1"]) + @@ -354,7 +354,12 @@ def tf_copts( ) def tf_openmp_copts(): - return (if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fno-openmp"])) + # We assume when compiling on Linux gcc/clang will be used and MSVC on Windows + return select({ + "@org_tensorflow//third_party/mkl:build_with_mkl_lnx_openmp": ["-fopenmp"], + "@org_tensorflow//third_party/mkl:build_with_mkl_windows_openmp": ["/openmp"], + "//conditions:default": [], + }) def tf_opts_nortti(): return [ @@ -1564,7 +1569,7 @@ def tf_mkl_kernel_library( hdrs = hdrs, deps = deps, alwayslink = alwayslink, - copts = copts, + copts = copts + if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]), features = disable_header_modules, ) @@ -2216,6 +2221,7 @@ def gpu_py_test( xla_enable_strict_auto_jit = False, xla_enabled = False, grpc_enabled = False, + xla_tags = [], # additional tags for xla_gpu tests **kwargs): if main == None: main = name + ".py" @@ -2238,7 +2244,7 @@ def gpu_py_test( kernels = kernels, main = main, shard_count = shard_count, - tags = test_tags + ["xla", "manual"], + tags = test_tags + xla_tags + ["xla", "manual"], xla_enabled = xla_enabled, xla_enable_strict_auto_jit = True, **kwargs @@ -2828,3 +2834,66 @@ def internal_tfrt_deps(): def internal_cuda_deps(): return [] + +def _tf_gen_options_header_impl(ctx): + header_depset = depset([ctx.outputs.output_header]) + + define_vals = {True: "true", False: "false"} + substitutions = {} + for target, identifier in ctx.attr.build_settings.items(): + setting_val = target[BuildSettingInfo].value + lines = [ + "// %s" % target.label, + "#define TF_OPTION_%s() %s" % (identifier, define_vals[setting_val]), + ] + substitutions["#define_option %s" % identifier] = "\n".join(lines) + + ctx.actions.expand_template( + template = ctx.file.template, + output = ctx.outputs.output_header, + substitutions = substitutions, + ) + + return [ + DefaultInfo(files = header_depset), + ] + +tf_gen_options_header = rule( + attrs = { + "output_header": attr.output( + doc = "File path for the generated header (output)", + mandatory = True, + ), + "template": attr.label( + doc = """Template for the header. + For each option name 'X' (see build_settings attribute), + '#define_option X' results in a macro 'TF_OPTION_X()' + """, + allow_single_file = True, + mandatory = True, + ), + "build_settings": attr.label_keyed_string_dict( + doc = """Dictionary from build-setting labels to option names. Example: + {"//tensorflow:x_setting" : "X"} + """, + providers = [BuildSettingInfo], + ), + }, + implementation = _tf_gen_options_header_impl, + doc = """ + Generates a header file for Bazel build settings. + + This is an alternative to setting preprocessor defines on the compiler + command line. It has a few advantages: + - Usage of the options requires #include-ing the header, and thus a + Bazel-level dependency. + - Each option has a definition site in source code, which mentions the + corresponding Bazel setting. This is particularly useful when + navigating code with the assistance of static analysis (e.g. + https://cs.opensource.google/tensorflow). + - Each option is represented as a FUNCTION()-style macro, which is always + defined (i.e. one uses #if instead of #ifdef). This allows forms like + 'if constexpr (TF_OPTION_FOO()) { ... }', and helps catch missing + dependencies (if 'F' is undefined, '#if F()' results in an error). + """, +) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt index d8eaf9bc7d7..63878ebc25d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt @@ -7,6 +7,10 @@ tf_class { name: "dispatcher_address" mtype: "" } + member { + name: "dispatcher_timeout_ms" + mtype: "" + } member { name: "heartbeat_interval_ms" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt index 0a8e0b4421a..c35e409975d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt @@ -8,11 +8,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -24,10 +24,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt index b5ccb39d075..68f62ba9e0b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt @@ -10,11 +10,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -26,10 +26,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt index 6a7a3a97aa0..df278d5cca5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt @@ -1,6 +1,10 @@ path: "tensorflow.distribute.InputReplicationMode" tf_class { is_instance: "" + member { + name: "PER_REPLICA" + mtype: "" + } member { name: "PER_WORKER" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt index 1a039b10501..266447848d0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt @@ -10,11 +10,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -26,10 +26,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt index 7876166dc40..55939070e23 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt @@ -9,11 +9,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -25,10 +25,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt index 0101212e4cc..b364014e55a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.distribute.ReplicaContext" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "devices" @@ -24,7 +25,7 @@ tf_class { } member_method { name: "all_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "merge_call" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt index b2d9d4ee2cb..7d7efcca81e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt @@ -37,7 +37,7 @@ tf_class { } member_method { name: "batch_reduce_to" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "broadcast_to" @@ -69,7 +69,7 @@ tf_class { } member_method { name: "reduce_to" - argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-communication.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-communication.pbtxt index 7eca1c80d8b..8803fbfea0b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-communication.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-communication.pbtxt @@ -1,16 +1,16 @@ path: "tensorflow.distribute.experimental.CollectiveCommunication" tf_class { - is_instance: "" + is_instance: "" member { name: "AUTO" - mtype: "" + mtype: "" } member { name: "NCCL" - mtype: "" + mtype: "" } member { name: "RING" - mtype: "" + mtype: "" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-communication-implementation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-communication-implementation.pbtxt new file mode 100644 index 00000000000..0ce1ded9192 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-communication-implementation.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.distribute.experimental.CommunicationImplementation" +tf_class { + is_instance: "" + member { + name: "AUTO" + mtype: "" + } + member { + name: "NCCL" + mtype: "" + } + member { + name: "RING" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-communication-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-communication-options.pbtxt new file mode 100644 index 00000000000..dfcb8954ac0 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-communication-options.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.distribute.experimental.CommunicationOptions" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'bytes_per_pack\', \'timeout_seconds\', \'implementation\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'CommunicationImplementation.AUTO\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt index 598a9dbc15b..0c08a47c72b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt @@ -18,7 +18,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CollectiveCommunication.AUTO\', \'None\'], " + argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CommunicationImplementation.AUTO\', \'None\'], " } member_method { name: "colocate_vars_with" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt index 9247db37925..edac99871a8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt @@ -12,6 +12,14 @@ tf_module { name: "CollectiveHints" mtype: "" } + member { + name: "CommunicationImplementation" + mtype: "" + } + member { + name: "CommunicationOptions" + mtype: "" + } member { name: "MultiWorkerMirroredStrategy" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt index c3a84b15dd6..dc60db610ff 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt @@ -20,4 +20,8 @@ tf_module { name: "output_all_intermediates" argspec: "args=[\'state\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "register_filesystem_plugin" + argspec: "args=[\'plugin_location\'], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt index da08722a7a3..41ab38df985 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -20,6 +24,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -128,6 +136,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -298,7 +310,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt index 1719c8bd9c7..48495e4ed13 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt @@ -14,6 +14,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -22,6 +26,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -130,6 +138,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -316,7 +328,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt index d93c018b073..1a23072830b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt @@ -13,6 +13,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -21,6 +25,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -129,6 +137,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -299,7 +311,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt index 1b50d327cd6..93dd8fa9972 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt index 7150f2bd928..97f2e75199a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt index 9fba915d01a..3e5cee9e574 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt @@ -13,6 +13,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -21,6 +25,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -129,6 +137,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -299,7 +311,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt index 51277dfae56..5495727a331 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -115,6 +123,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt index 378f6568eef..610fe3840a9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt index a9d11967feb..5a5539404dd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt index fca5d2928ee..2ec76e78212 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt index 0c85d31934a..b1780f501ed 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt index 4ca1dc4a217..b7053964818 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt index b2c3156cf7a..9481cbc18f4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt index ae64e051158..4fc49aa433a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt index fd77d449216..ec5df5c667a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt index fc39c337669..4e3d41ee5dd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt index cbcfbb1022f..2b092efb833 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt index bbb6c19bd7f..81124e1682b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt index 16d329f22c1..e861fda8107 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt index 56cb840bd0b..161e0372d63 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt index 6dc759e1338..b3d992d5fdc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt index a619ae0a480..7e926162ade 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "constraints" mtype: "" @@ -20,6 +24,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -112,6 +120,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt index 237d4e7f34c..e716c8c4921 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index dc15a2c227c..3c34852d7e6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -29,6 +29,10 @@ tf_class { name: "bias_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "data_format" mtype: "" @@ -45,6 +49,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -197,6 +205,10 @@ tf_class { name: "use_bias" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d-transpose.pbtxt index 3d3ee3c67bf..8c4c827b92b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt index 23fb7bfc4eb..b503a8c75f6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt index a7eeb12ef04..43e6f7f332e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt index 9f4aa3cd95f..2a33509d196 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index a83cbc24972..9a1fc4949e7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt index 7ccbb9a2694..d6c0ad0d151 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt index 0733557f70d..66cdc51c355 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt index 71c2e77e7ff..517f2505bdb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt index 824bd8bbb2f..a609dcdb55f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt index ac9d5be1883..10adc4ce928 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index ed63b8d98d4..ea0ee3028b3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt index d00f7a5b396..6e5d6bc942e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt index 2ca122485d6..c366341d8a4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt index 1b69967a59e..ce86e9a44b4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt index 265f13b06bd..d46d79f30de 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt index e40eb1470a7..e3a3e3d9be5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt @@ -17,10 +17,18 @@ tf_class { name: "cell" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -117,6 +125,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt index 167a4d9e96f..33d3999c0a0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt @@ -17,10 +17,18 @@ tf_class { name: "cell" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -117,6 +125,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt index 59793ff6d45..7dd3f40faaf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt index 8370406e34d..bc7ddcea718 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt index 554d7531912..891a5b201b2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt index 101719437d4..5ba7cd1a385 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt index d441302523f..66045e97b8f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt index a736e4c03fd..8ba42c66f06 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt index 3f9002792d8..5f5149dc5c2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt index 217f2701f3d..8bdcee57fa2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt index a4f9e447bdf..db3514efe73 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt index f0e84b8edd7..de7ba971a32 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "bias_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dropout" mtype: "" @@ -36,6 +40,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -180,6 +188,10 @@ tf_class { name: "use_bias" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt index 1962f3284f0..6273d7986a4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt index 64073b27c24..e6debc89415 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt index 73ed7f59394..08c9a100161 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt index 2fb47e8a5a6..d6e3f1252cb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt index e3ab2b5ab6d..94e9a96b88f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt index 494e2247fbf..00afd5fc967 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt index 22e79311d67..18d826fdd13 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt index 83f91393647..92e8876f5c4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt index d211683ae9c..f9b8b1778f4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt index 2b9442dee85..4bfcac30602 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt index e02b42bdd0e..0762ac1e767 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt index 60d2a947d87..69ff605469c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt index 352527ea0f4..e05fb04f0b2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt index a1ff2f402a7..1268fbf60b0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt index d1811a28b55..774a1c23255 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 13fbf554fd3..639efba7a75 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt index ff9e8b6df74..a1536d0b82e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "bias_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dropout" mtype: "" @@ -36,6 +40,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -180,6 +188,10 @@ tf_class { name: "use_bias" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt index c9f01c56606..211c366d9ea 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer-normalization.pbtxt index 3b9306cdfe6..3c293f42da3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer-normalization.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt index 03902ed1de4..962b7891b90 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt @@ -10,10 +10,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -106,6 +114,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt index bf98a150184..3bd0b6a6071 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt index 040230d63b3..0d178d92f06 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt index 8d49e7a58a1..da65b262707 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt index 485ae3b16ef..8dcd821c0e1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt index 05050fdbffa..cce7e92fde3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt index 8ae6a0ab43b..390cb52e228 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt index ae8aea28552..4c2ddc2ce12 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt index 94d2e0e6f6e..e7956c1029f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt index 91b0b44ea50..014af0a3b63 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt index 587850f1d6c..7c1a8e8820e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt index ac97ca6e061..52286c20c58 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt index 7c8950ce3fa..3ae00343db3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt index 89fbb32194a..001a8743a91 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt index 9ef978eeb3a..1b9ec5357fe 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt index 19a48d77113..6b26c702b12 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt index 03d5a2195cd..c79315618a4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt index c8c5b8326dd..c09439b7028 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt index 84530e067b2..4d9f2c18b80 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt index 4de5b1c20d8..a9ab889dd55 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt index b9fcb027aaf..4ac4c4fe8f6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt index 5b6bff9dc5f..664b4ec2547 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt index 3fb3c032a3e..6ab78380d99 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt index 5387a8e5fc5..5e5aff38b19 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index de2d3eaaab4..7560106c96e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 80e17948612..4b915342922 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt index 48e0c26b010..f5b25907cc6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "bias_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dropout" mtype: "" @@ -36,6 +40,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -168,6 +176,10 @@ tf_class { name: "use_bias" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt index 97e4b91bfa3..aebd00c9d98 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt index 0260221d093..415259bfbcf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt index ddad5641e76..59c31486278 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt index 47e6ba9abfa..760e6592808 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index 5379da642ed..bd454ac60ec 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -115,6 +123,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt index 1e070fb36db..e9c00b51415 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt index 6d7724bdfe3..407c860c033 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt index d740fc8de3a..d0f7619d1bf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt index de377d9d2eb..7561273844e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt index e2ee7941662..428f43d7523 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt index 8dd967cd3ce..482ecf32e2a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt index 334463a4031..e902ff4a483 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt index ba4e58e3dfb..2312af6ebac 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt index 8538d903fba..8937e9de579 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt index 2b8681ae8cc..795ad24a3ac 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt index a2d7d285409..0242b10c649 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-random-fourier-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-random-fourier-features.pbtxt index e24ca0dc01a..425c5e45934 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-random-fourier-features.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-random-fourier-features.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt index ceb38316d11..1b5f2fc7a10 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt index 04e59727b19..79946192d5b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt index 14d43cb08e8..d09d0f85402 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt index cb7a793f94d..628f76c84a3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt index 75a1efc2f15..2c9af8c3b32 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-integer-lookup.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-integer-lookup.pbtxt index 75625d24d30..464ca87b9bc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-integer-lookup.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-integer-lookup.pbtxt @@ -17,10 +17,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -113,6 +121,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt index 093a2b2292e..916149fa9f1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt index 0fa0355b0f2..62e1c15af95 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt index 8eca3903616..f5c73972d6a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt index ad813468f53..55280e81038 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt index 15406e778f8..69fb7bea570 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt index 8119cb9687f..80d7fda8d0d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt index 7c68f9ef783..131bc4b3efd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt index 73866dfcf50..c212112a64f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt index 0d113434d80..de032c6566c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt index 392e2efef39..ab55f40e8fc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt index c8fcedd3221..96882e04598 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt index 7efb8d72dcb..3dc758729eb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-string-lookup.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-string-lookup.pbtxt index 2d7e71c2c43..515f976574b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-string-lookup.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-string-lookup.pbtxt @@ -17,10 +17,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -113,6 +121,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt index 532f98fb322..69ae8d67722 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt index 96703a02f88..b0d2b891dbc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -112,6 +120,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt index e1b19dc5836..24026cc795b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt index 64edac45aa4..0cd26f8c214 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-crossentropy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-crossentropy.pbtxt index 8130de478cd..1055718babd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-crossentropy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt index 7bfcf1f2bdd..29d1597747a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt index 4e37d89dba7..ff9a7004c4e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-hinge.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-hinge.pbtxt index 9bb3b783982..43b95f6d23b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-hinge.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-cosine-similarity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-cosine-similarity.pbtxt index 4defb3657ec..270bd27279f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-cosine-similarity.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-cosine-similarity.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt index 628363c3347..0c88ab6d8b5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt index 73c30e9a080..b2a8cbe0703 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-hinge.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-hinge.pbtxt index fe3b145999f..7aaa8e69472 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-hinge.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-k-l-divergence.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-k-l-divergence.pbtxt index e167d1a743c..c581d3ba734 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-k-l-divergence.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-k-l-divergence.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-log-cosh-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-log-cosh-error.pbtxt index 83d2c77a759..6f561ab04a3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-log-cosh-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-log-cosh-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-error.pbtxt index 90f18e32247..47572bb065e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt index cab559c90cb..b91104ea20b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-io-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-io-u.pbtxt index 091cf58e3e2..5c5e0992d47 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-io-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-io-u.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-relative-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-relative-error.pbtxt index faf78e01f38..5376ff099c3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-relative-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-relative-error.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -110,6 +118,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-error.pbtxt index d4cb6e5b7ad..eb37f65f3a2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt index 32a1cffa204..5a411922095 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-tensor.pbtxt index 788f2e316b0..fd85ec63408 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean-tensor.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "count" mtype: "" @@ -20,6 +24,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -116,6 +124,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt index 76981e1950c..d0bb8dfc079 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-metric.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-metric.pbtxt index 65c65f71f6a..c9c15d8485c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-metric.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-metric.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-poisson.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-poisson.pbtxt index 22ba826b26e..a3676a52d84 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-poisson.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-poisson.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision-at-recall.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision-at-recall.pbtxt index c6bd13d4f1e..8a4ad3060e4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision-at-recall.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision-at-recall.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt index eaf6bbeb78d..bd55beef8aa 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall-at-precision.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall-at-precision.pbtxt index 7254e26db7f..2f91c365dff 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall-at-precision.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall-at-precision.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt index 1f0d977ea23..ab1e70387ad 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt index c6f2ea193d4..1bab1fb18a1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -110,6 +118,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt index 1c383c8dcf0..827f812b6c0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt index 11c694eb234..ad026eee5d7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt index c95806164ac..7b2c8aad4f5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt index 01b883278e5..9cb2fb31167 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt index 3d90bfc1ac9..cae343a0b79 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-squared-hinge.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-squared-hinge.pbtxt index 25858b461b3..d947aa57cc4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-squared-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-squared-hinge.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sum.pbtxt index 065036aee12..90c6ec1b640 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sum.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sum.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt index 2fac29e2415..561a13f6625 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt index 8fbf24a66ad..262ec6a3ed0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt index 4d1f2fd213f..af3f67971f1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt new file mode 100644 index 00000000000..3bea3a9f8fe --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt @@ -0,0 +1,116 @@ +path: "tensorflow.keras.mixed_precision.LossScaleOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "clipnorm" + mtype: "" + } + member { + name: "clipvalue" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "dynamic_counter" + mtype: "" + } + member { + name: "dynamic_growth_steps" + mtype: "" + } + member { + name: "global_clipnorm" + mtype: "" + } + member { + name: "initial_scale" + mtype: "" + } + member { + name: "inner_optimizer" + mtype: "" + } + member { + name: "iterations" + mtype: "" + } + member { + name: "loss_scale" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'inner_optimizer\', \'dynamic\', \'initial_scale\', \'dynamic_growth_steps\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], " + } + member_method { + name: "add_slot" + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'name\', \'experimental_aggregate_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_scaled_loss" + argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_unscaled_gradients" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\', \'tape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt index 3c016d331de..a68c17e8f24 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt @@ -1,7 +1,8 @@ path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -13,10 +14,30 @@ tf_class { name: "clipvalue" mtype: "" } + member { + name: "dynamic" + mtype: "" + } + member { + name: "dynamic_counter" + mtype: "" + } + member { + name: "dynamic_growth_steps" + mtype: "" + } member { name: "global_clipnorm" mtype: "" } + member { + name: "initial_scale" + mtype: "" + } + member { + name: "inner_optimizer" + mtype: "" + } member { name: "iterations" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt index e8648afb5f7..502aef38b75 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.keras.mixed_precision" tf_module { + member { + name: "LossScaleOptimizer" + mtype: "" + } member { name: "experimental" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt index 15c0ab5abbb..0c837d030c3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -20,6 +24,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -128,6 +136,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -298,7 +310,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt index 729fdd660ca..4a595abe9db 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt @@ -14,6 +14,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -22,6 +26,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -130,6 +138,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -316,7 +328,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt index ac80126aaa3..798dd83194d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt @@ -30,6 +30,6 @@ tf_module { } member_method { name: "save_model" - argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt index 0fa2415c54f..301ef87309b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt index 871adc80846..c549934cd98 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt index 9b67ca9bee9..44fa4b26250 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt index 905b8c028bb..944e372cb6f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt index ee778d9bead..13e5dce8ccf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt index 7c6958e4d9f..7a02a74cac1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -119,6 +127,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt index e88b843679c..cf0049e0762 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt index f31110ba8fe..e8801dc63fc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -119,6 +127,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt index a3f7e557ddc..79551b4dc93 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt index 930706bec9d..7543d366593 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -117,6 +125,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt index bb57573198a..f44894a17c2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -117,6 +125,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt index 1dff295f333..caa74ad4f97 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -117,6 +125,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt index 108d1b94bc6..6795c5357c3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -115,6 +123,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt index b65100b2ace..a43f13dd06a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt index 8b44d2eca70..4f0899757e5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt index 2f2b7e8830b..694624eb931 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt index ae06fe06994..c58d180acf7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -119,6 +127,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt index 330528e2df6..d6733740a2f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -119,6 +127,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt index fc8f8cb0478..de4b1027f89 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -126,6 +134,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt index 9cd3bbd8732..e9a1481b03a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -126,6 +134,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt index 1b865125114..2b2cbe615c3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt @@ -144,6 +144,10 @@ tf_module { name: "erfc" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "erfcinv" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "erfinv" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-dynamic-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-dynamic-loss-scale.pbtxt new file mode 100644 index 00000000000..c744ae30e11 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-dynamic-loss-scale.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.mixed_precision.DynamicLossScale" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "increment_period" + mtype: "" + } + member { + name: "initial_loss_scale" + mtype: "" + } + member { + name: "multiplier" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_loss_scale\', \'increment_period\', \'multiplier\'], varargs=None, keywords=None, defaults=[\'32768\', \'2000\', \'2.0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-fixed-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-fixed-loss-scale.pbtxt new file mode 100644 index 00000000000..7393181eb85 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-fixed-loss-scale.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.mixed_precision.FixedLossScale" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'loss_scale_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-loss-scale.pbtxt new file mode 100644 index 00000000000..044b49a9999 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-loss-scale.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.mixed_precision.LossScale" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-mixed-precision-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-mixed-precision-loss-scale-optimizer.pbtxt new file mode 100644 index 00000000000..f1e49106bf3 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-mixed-precision-loss-scale-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.mixed_precision.MixedPrecisionLossScaleOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.pbtxt index 475c4a2ccde..b9020a1d912 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.pbtxt @@ -1,7 +1,31 @@ path: "tensorflow.mixed_precision" tf_module { + member { + name: "DynamicLossScale" + mtype: "" + } + member { + name: "FixedLossScale" + mtype: "" + } + member { + name: "LossScale" + mtype: "" + } + member { + name: "MixedPrecisionLossScaleOptimizer" + mtype: "" + } member { name: "experimental" mtype: "" } + member_method { + name: "disable_mixed_precision_graph_rewrite" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "enable_mixed_precision_graph_rewrite" + argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], " + } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt index f69f9d6fe7d..91e2a674648 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -126,6 +134,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt index c9d09d08a67..956ba36fbf1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -126,6 +134,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index 41bafa62ff9..bc842ea559a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -127,6 +135,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index e2ddb7feafe..bb1577af51b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -127,6 +135,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index 11f02509c12..f2068c6a4b3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -126,6 +134,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index d8f03cda8b5..a009efb9bd6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -126,6 +134,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index de531d7a11a..6acd3afd1c1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -125,6 +133,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index d678fae51f1..50926b63481 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -124,6 +132,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index b4449602033..823a3364b17 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -127,6 +135,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 2efd289c259..96be23b9e50 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -3212,6 +3212,10 @@ tf_module { name: "RaggedTensorToVariant" argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "RaggedTensorToVariantGradient" + argspec: "args=[\'encoded_ragged_grad\', \'row_splits\', \'dense_values_shape\', \'Tvalues\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "RandomCrop" argspec: "args=[\'image\', \'size\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-adagrad-parameters.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-adagrad-parameters.pbtxt index 5177e95357a..f56ef6540fc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-adagrad-parameters.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-adagrad-parameters.pbtxt @@ -5,6 +5,6 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-adam-parameters.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-adam-parameters.pbtxt index 68c62d5398b..922f4c3adc1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-adam-parameters.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-adam-parameters.pbtxt @@ -5,6 +5,6 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-ftrl-parameters.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-ftrl-parameters.pbtxt index 450015c3695..648c8340b70 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-ftrl-parameters.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-ftrl-parameters.pbtxt @@ -5,6 +5,6 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'multiply_linear_by_learning_rate\', \'beta\', \'allow_zero_accumulator\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'True\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0\', \'False\'], " + argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'multiply_linear_by_learning_rate\', \'beta\', \'allow_zero_accumulator\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'True\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0\', \'False\', \'None\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-stochastic-gradient-descent-parameters.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-stochastic-gradient-descent-parameters.pbtxt index 1b0bcf8be72..70b886b0324 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-stochastic-gradient-descent-parameters.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.-stochastic-gradient-descent-parameters.pbtxt @@ -5,6 +5,6 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'clip_gradient_min\', \'clip_gradient_max\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt index d8eaf9bc7d7..63878ebc25d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt @@ -7,6 +7,10 @@ tf_class { name: "dispatcher_address" mtype: "" } + member { + name: "dispatcher_timeout_ms" + mtype: "" + } member { name: "heartbeat_interval_ms" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt index 0a8e0b4421a..c35e409975d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt @@ -8,11 +8,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -24,10 +24,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt index b5ccb39d075..68f62ba9e0b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt @@ -10,11 +10,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -26,10 +26,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-options.pbtxt index c3beabd938e..123e29cb163 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-options.pbtxt @@ -3,10 +3,18 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "experimental_place_dataset_on_device" + mtype: "" + } member { name: "experimental_prefetch_to_device" mtype: "" } + member { + name: "experimental_replication_mode" + mtype: "" + } member_method { name: "__init__" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt index 6a7a3a97aa0..df278d5cca5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt @@ -1,6 +1,10 @@ path: "tensorflow.distribute.InputReplicationMode" tf_class { is_instance: "" + member { + name: "PER_REPLICA" + mtype: "" + } member { name: "PER_WORKER" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt index 60c0e8d7663..75b0dd33fd3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-multi-worker-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-multi-worker-mirrored-strategy.pbtxt new file mode 100644 index 00000000000..02816f1c5f6 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-multi-worker-mirrored-strategy.pbtxt @@ -0,0 +1,91 @@ +path: "tensorflow.distribute.MultiWorkerMirroredStrategy" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "cluster_resolver" + mtype: "" + } + member { + name: "extended" + mtype: "" + } + member { + name: "num_replicas_in_sync" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cluster_resolver\', \'communication_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "colocate_vars_with" + argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "configure" + argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "experimental_distribute_datasets_from_function" + argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "experimental_distribute_values_from_function" + argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_local_results" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_run" + argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "group" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "make_dataset_iterator" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_input_fn_iterator" + argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], " + } + member_method { + name: "reduce" + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run" + argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], " + } + member_method { + name: "scope" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "unwrap" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update_config_proto" + argspec: "args=[\'self\', \'config_proto\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt index 1a039b10501..266447848d0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt @@ -10,11 +10,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -26,10 +26,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt index 17af12cf279..559ee5e9519 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt @@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt index 7876166dc40..55939070e23 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt @@ -9,11 +9,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -25,10 +25,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt index 0101212e4cc..7379ddc856d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.distribute.ReplicaContext" tf_class { is_instance: "" + is_instance: "" is_instance: "" member { name: "devices" @@ -22,9 +23,13 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "all_gather" + argspec: "args=[\'self\', \'value\', \'axis\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "all_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "merge_call" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt index 75e34579e5c..a1447c2ded0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt @@ -20,7 +20,7 @@ tf_class { } member_method { name: "batch_reduce_to" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "colocate_vars_with" @@ -28,7 +28,7 @@ tf_class { } member_method { name: "reduce_to" - argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt index 702cdc98e88..5991d60fd81 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt @@ -51,6 +51,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt index 0cb06ec3b01..3d8e791613a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt @@ -64,6 +64,10 @@ tf_class { name: "experimental_split_to_logical_devices" argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt index e1b92b44b73..00d7d652a89 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt @@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-communication.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-communication.pbtxt index 7eca1c80d8b..8803fbfea0b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-communication.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-communication.pbtxt @@ -1,16 +1,16 @@ path: "tensorflow.distribute.experimental.CollectiveCommunication" tf_class { - is_instance: "" + is_instance: "" member { name: "AUTO" - mtype: "" + mtype: "" } member { name: "NCCL" - mtype: "" + mtype: "" } member { name: "RING" - mtype: "" + mtype: "" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-communication-implementation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-communication-implementation.pbtxt new file mode 100644 index 00000000000..0ce1ded9192 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-communication-implementation.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.distribute.experimental.CommunicationImplementation" +tf_class { + is_instance: "" + member { + name: "AUTO" + mtype: "" + } + member { + name: "NCCL" + mtype: "" + } + member { + name: "RING" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-communication-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-communication-options.pbtxt new file mode 100644 index 00000000000..dfcb8954ac0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-communication-options.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.distribute.experimental.CommunicationOptions" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'bytes_per_pack\', \'timeout_seconds\', \'implementation\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'CommunicationImplementation.AUTO\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt index fda255eac17..78f189a0d54 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt @@ -1,5 +1,6 @@ path: "tensorflow.distribute.experimental.MultiWorkerMirroredStrategy" tf_class { + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -18,7 +19,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CollectiveCommunication.AUTO\', \'None\'], " + argspec: "args=[\'self\', \'communication\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'CommunicationImplementation.AUTO\', \'None\'], " } member_method { name: "colocate_vars_with" @@ -52,6 +53,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt index e445d7c4dab..9393d410228 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.distribute.experimental.ParameterServerStrategy" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -18,7 +18,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'cluster_resolver\', \'variable_partitioner\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "colocate_vars_with" @@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt index 6536eefe414..25c525f8e18 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt @@ -52,6 +52,10 @@ tf_class { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "gather" + argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group" argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.-cluster-coordinator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.-cluster-coordinator.pbtxt new file mode 100644 index 00000000000..3381eb1036a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.-cluster-coordinator.pbtxt @@ -0,0 +1,33 @@ +path: "tensorflow.distribute.experimental.coordinator.ClusterCoordinator" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "strategy" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'strategy\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "create_per_worker_dataset" + argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "done" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "fetch" + argspec: "args=[\'self\', \'val\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "join" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "schedule" + argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.-per-worker-values.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.-per-worker-values.pbtxt new file mode 100644 index 00000000000..5831f47dcfd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.-per-worker-values.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.distribute.experimental.coordinator.PerWorkerValues" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'values\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.-remote-value.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.-remote-value.pbtxt new file mode 100644 index 00000000000..f0e32de4aa1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.-remote-value.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.distribute.experimental.coordinator.RemoteValue" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "fetch" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.pbtxt new file mode 100644 index 00000000000..e394c0195c2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.coordinator.pbtxt @@ -0,0 +1,15 @@ +path: "tensorflow.distribute.experimental.coordinator" +tf_module { + member { + name: "ClusterCoordinator" + mtype: "" + } + member { + name: "PerWorkerValues" + mtype: "" + } + member { + name: "RemoteValue" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-fixed-shards-partitioner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-fixed-shards-partitioner.pbtxt new file mode 100644 index 00000000000..622d1666ff3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-fixed-shards-partitioner.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.distribute.experimental.partitioners.FixedShardsPartitioner" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_shards\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-max-size-partitioner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-max-size-partitioner.pbtxt new file mode 100644 index 00000000000..0dab66b68b5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-max-size-partitioner.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.distribute.experimental.partitioners.MaxSizePartitioner" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_shard_bytes\', \'max_shards\', \'bytes_per_string\'], varargs=None, keywords=None, defaults=[\'None\', \'16\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-min-size-partitioner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-min-size-partitioner.pbtxt new file mode 100644 index 00000000000..f20b23aeffa --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-min-size-partitioner.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.distribute.experimental.partitioners.MinSizePartitioner" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'min_shard_bytes\', \'max_shards\', \'bytes_per_string\'], varargs=None, keywords=None, defaults=[\'262144\', \'1\', \'16\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-partitioner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-partitioner.pbtxt new file mode 100644 index 00000000000..cca4cd510c9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.-partitioner.pbtxt @@ -0,0 +1,8 @@ +path: "tensorflow.distribute.experimental.partitioners.Partitioner" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.pbtxt new file mode 100644 index 00000000000..70d36b14b64 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.partitioners.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.distribute.experimental.partitioners" +tf_module { + member { + name: "FixedShardsPartitioner" + mtype: "" + } + member { + name: "MaxSizePartitioner" + mtype: "" + } + member { + name: "MinSizePartitioner" + mtype: "" + } + member { + name: "Partitioner" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt index 06151eee4b4..77b3c5caebc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt @@ -13,9 +13,17 @@ tf_module { mtype: "" } member { - name: "MultiWorkerMirroredStrategy" + name: "CommunicationImplementation" + mtype: "" + } + member { + name: "CommunicationOptions" mtype: "" } + member { + name: "MultiWorkerMirroredStrategy" + mtype: "" + } member { name: "ParameterServerStrategy" mtype: "" @@ -28,4 +36,12 @@ tf_module { name: "ValueContext" mtype: "" } + member { + name: "coordinator" + mtype: "" + } + member { + name: "partitioners" + mtype: "" + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt index d3867889a4f..9bd37181958 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt @@ -36,6 +36,10 @@ tf_module { name: "MirroredStrategy" mtype: "" } + member { + name: "MultiWorkerMirroredStrategy" + mtype: "" + } member { name: "NcclAllReduce" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt index 58384846276..33c28d715a5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt @@ -28,4 +28,8 @@ tf_module { name: "function_executor_type" argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "register_filesystem_plugin" + argspec: "args=[\'plugin_location\'], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt index da08722a7a3..41ab38df985 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -20,6 +24,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -128,6 +136,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -298,7 +310,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt index 1719c8bd9c7..48495e4ed13 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt @@ -14,6 +14,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -22,6 +26,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -130,6 +138,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -316,7 +328,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt index d93c018b073..1a23072830b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt @@ -13,6 +13,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -21,6 +25,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -129,6 +137,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -299,7 +311,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt index 1b50d327cd6..93dd8fa9972 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt index 7150f2bd928..97f2e75199a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt index 9fba915d01a..3e5cee9e574 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt @@ -13,6 +13,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -21,6 +25,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -129,6 +137,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -299,7 +311,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt index 51277dfae56..5495727a331 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-abstract-r-n-n-cell.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -115,6 +123,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt index 378f6568eef..610fe3840a9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt index a9d11967feb..5a5539404dd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt index fca5d2928ee..2ec76e78212 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt index 0c85d31934a..b1780f501ed 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt index 4ca1dc4a217..b7053964818 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt index b2c3156cf7a..9481cbc18f4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt index ae64e051158..4fc49aa433a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt index fd77d449216..ec5df5c667a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt index fc39c337669..4e3d41ee5dd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt index cbcfbb1022f..2b092efb833 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt index bbb6c19bd7f..81124e1682b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt index 16d329f22c1..e861fda8107 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt index 56cb840bd0b..161e0372d63 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt index fd130c55979..d50d4bcf5c4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt index a619ae0a480..7e926162ade 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "constraints" mtype: "" @@ -20,6 +24,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -112,6 +120,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt index 237d4e7f34c..e716c8c4921 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index dc15a2c227c..3c34852d7e6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -29,6 +29,10 @@ tf_class { name: "bias_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "data_format" mtype: "" @@ -45,6 +49,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -197,6 +205,10 @@ tf_class { name: "use_bias" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d-transpose.pbtxt index 3d3ee3c67bf..8c4c827b92b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt index 23fb7bfc4eb..b503a8c75f6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt index a7eeb12ef04..43e6f7f332e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt index 9f4aa3cd95f..2a33509d196 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index a83cbc24972..9a1fc4949e7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt index 7ccbb9a2694..d6c0ad0d151 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt index 0733557f70d..66cdc51c355 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt index 71c2e77e7ff..517f2505bdb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt index 824bd8bbb2f..a609dcdb55f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt index ac9d5be1883..10adc4ce928 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index ed63b8d98d4..ea0ee3028b3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt index d00f7a5b396..6e5d6bc942e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt index 2ca122485d6..c366341d8a4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt index 1b69967a59e..ce86e9a44b4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt index 265f13b06bd..d46d79f30de 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt index c74a5868d98..a73b2ad5cb7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt index 8370406e34d..bc7ddcea718 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt index 554d7531912..891a5b201b2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt index 101719437d4..5ba7cd1a385 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt index d441302523f..66045e97b8f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt index a736e4c03fd..8ba42c66f06 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt index 3f9002792d8..5f5149dc5c2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt index 217f2701f3d..8bdcee57fa2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt index 1a2338fe077..38aa2bb7199 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt index 7fd64ab47ca..f0f34aa6796 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt @@ -30,6 +30,10 @@ tf_class { name: "bias_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dropout" mtype: "" @@ -38,6 +42,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -182,6 +190,10 @@ tf_class { name: "use_bias" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt index 1962f3284f0..6273d7986a4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt index 64073b27c24..e6debc89415 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt index 73ed7f59394..08c9a100161 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt index 2fb47e8a5a6..d6e3f1252cb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt index e3ab2b5ab6d..94e9a96b88f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt index 494e2247fbf..00afd5fc967 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt index 22e79311d67..18d826fdd13 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt index 83f91393647..92e8876f5c4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt index d211683ae9c..f9b8b1778f4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt index 2b9442dee85..4bfcac30602 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt index e02b42bdd0e..0762ac1e767 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt index 60d2a947d87..69ff605469c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt index 352527ea0f4..e05fb04f0b2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt index a1ff2f402a7..1268fbf60b0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt index d1811a28b55..774a1c23255 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index c39df6fa394..42cf70b4bff 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt index 3877cf015a2..67834ac5768 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -30,6 +30,10 @@ tf_class { name: "bias_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dropout" mtype: "" @@ -38,6 +42,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -182,6 +190,10 @@ tf_class { name: "use_bias" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt index c9f01c56606..211c366d9ea 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer-normalization.pbtxt index 3b9306cdfe6..3c293f42da3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer-normalization.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt index 03902ed1de4..962b7891b90 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt @@ -10,10 +10,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -106,6 +114,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt index bf98a150184..3bd0b6a6071 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt index 040230d63b3..0d178d92f06 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt index 8d49e7a58a1..da65b262707 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt index 485ae3b16ef..8dcd821c0e1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt index 05050fdbffa..cce7e92fde3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt index 8ae6a0ab43b..390cb52e228 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt index ae8aea28552..4c2ddc2ce12 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt index 94d2e0e6f6e..e7956c1029f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt index 91b0b44ea50..014af0a3b63 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt index 587850f1d6c..7c1a8e8820e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt index ac97ca6e061..52286c20c58 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt index 7c8950ce3fa..3ae00343db3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt index 89fbb32194a..001a8743a91 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt index 9ef978eeb3a..1b9ec5357fe 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt index 19a48d77113..6b26c702b12 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt index 03d5a2195cd..c79315618a4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt index c8c5b8326dd..c09439b7028 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt index 84530e067b2..4d9f2c18b80 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt index 4de5b1c20d8..a9ab889dd55 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt index b9fcb027aaf..4ac4c4fe8f6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt index 5b6bff9dc5f..664b4ec2547 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt index 3fb3c032a3e..6ab78380d99 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt index 5387a8e5fc5..5e5aff38b19 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index de2d3eaaab4..7560106c96e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 80e17948612..4b915342922 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt index 48e0c26b010..f5b25907cc6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "bias_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dropout" mtype: "" @@ -36,6 +40,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -168,6 +176,10 @@ tf_class { name: "use_bias" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt index 97e4b91bfa3..aebd00c9d98 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt index 0260221d093..415259bfbcf 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt index ddad5641e76..59c31486278 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt index 47e6ba9abfa..760e6592808 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index 5379da642ed..bd454ac60ec 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -115,6 +123,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt index 1e070fb36db..e9c00b51415 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt index 6d7724bdfe3..407c860c033 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt index d740fc8de3a..d0f7619d1bf 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt index de377d9d2eb..7561273844e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt index e2ee7941662..428f43d7523 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt index 8dd967cd3ce..482ecf32e2a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt index 334463a4031..e902ff4a483 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt index ba4e58e3dfb..2312af6ebac 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt index 8538d903fba..8937e9de579 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt index 2b8681ae8cc..795ad24a3ac 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt index a2d7d285409..0242b10c649 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-random-fourier-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-random-fourier-features.pbtxt index e24ca0dc01a..425c5e45934 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-random-fourier-features.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-random-fourier-features.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt index f34dce7b307..661e1085887 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt index ceb38316d11..1b5f2fc7a10 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt index b4662d3c0e9..80d7a618df8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt index 14d43cb08e8..d09d0f85402 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-center-crop.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt index cb7a793f94d..628f76c84a3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt index 75a1efc2f15..2c9af8c3b32 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-integer-lookup.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-integer-lookup.pbtxt index e4e24a25b7b..0c8453339af 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-integer-lookup.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-integer-lookup.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -110,6 +118,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt index c66604af334..d8ad9a9f683 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt index 0fa0355b0f2..62e1c15af95 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-preprocessing-layer.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt index 8eca3903616..f5c73972d6a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-contrast.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt index ad813468f53..55280e81038 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt index 15406e778f8..69fb7bea570 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-flip.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt index 8119cb9687f..80d7fda8d0d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-height.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt index 7c68f9ef783..131bc4b3efd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt index 73866dfcf50..c212112a64f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt index 0d113434d80..de032c6566c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-width.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt index 392e2efef39..ab55f40e8fc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt index c8fcedd3221..96882e04598 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt index 7efb8d72dcb..3dc758729eb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-resizing.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-string-lookup.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-string-lookup.pbtxt index 80da4a3df58..e628cc4d7f7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-string-lookup.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-string-lookup.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -110,6 +118,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt index 9fc7c410480..c5eab66f364 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt index 96703a02f88..b0d2b891dbc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -112,6 +120,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt index e1b19dc5836..24026cc795b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt index 64edac45aa4..0cd26f8c214 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-crossentropy.pbtxt index 8130de478cd..1055718babd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-crossentropy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt index 7bfcf1f2bdd..29d1597747a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt index 4e37d89dba7..ff9a7004c4e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-crossentropy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-hinge.pbtxt index 9bb3b783982..43b95f6d23b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-hinge.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-cosine-similarity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-cosine-similarity.pbtxt index 4defb3657ec..270bd27279f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-cosine-similarity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-cosine-similarity.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt index 628363c3347..0c88ab6d8b5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt index 73c30e9a080..b2a8cbe0703 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-hinge.pbtxt index fe3b145999f..7aaa8e69472 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-hinge.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-k-l-divergence.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-k-l-divergence.pbtxt index e167d1a743c..c581d3ba734 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-k-l-divergence.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-k-l-divergence.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-log-cosh-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-log-cosh-error.pbtxt index 83d2c77a759..6f561ab04a3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-log-cosh-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-log-cosh-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-error.pbtxt index 90f18e32247..47572bb065e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt index cab559c90cb..b91104ea20b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-absolute-percentage-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-io-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-io-u.pbtxt index 091cf58e3e2..5c5e0992d47 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-io-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-io-u.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-relative-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-relative-error.pbtxt index faf78e01f38..5376ff099c3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-relative-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-relative-error.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -110,6 +118,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-error.pbtxt index d4cb6e5b7ad..eb37f65f3a2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt index 32a1cffa204..5a411922095 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-squared-logarithmic-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-tensor.pbtxt index 788f2e316b0..fd85ec63408 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean-tensor.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "count" mtype: "" @@ -20,6 +24,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -116,6 +124,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt index 76981e1950c..d0bb8dfc079 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-metric.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-metric.pbtxt index 65c65f71f6a..c9c15d8485c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-metric.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-metric.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-poisson.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-poisson.pbtxt index 22ba826b26e..a3676a52d84 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-poisson.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-poisson.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision-at-recall.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision-at-recall.pbtxt index c6bd13d4f1e..8a4ad3060e4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision-at-recall.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision-at-recall.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt index eaf6bbeb78d..bd55beef8aa 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall-at-precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall-at-precision.pbtxt index 7254e26db7f..2f91c365dff 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall-at-precision.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall-at-precision.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt index 1f0d977ea23..ab1e70387ad 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt index c6f2ea193d4..1bab1fb18a1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-root-mean-squared-error.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -110,6 +118,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt index 1c383c8dcf0..827f812b6c0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sensitivity-at-specificity.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt index 11c694eb234..ad026eee5d7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt index c95806164ac..7b2c8aad4f5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-crossentropy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt index 01b883278e5..9cb2fb31167 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-top-k-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt index 3d90bfc1ac9..cae343a0b79 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-specificity-at-sensitivity.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-squared-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-squared-hinge.pbtxt index 25858b461b3..d947aa57cc4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-squared-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-squared-hinge.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sum.pbtxt index 065036aee12..90c6ec1b640 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sum.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sum.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt index 2fac29e2415..561a13f6625 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-top-k-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt index 8fbf24a66ad..262ec6a3ed0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt index 4d1f2fd213f..af3f67971f1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt new file mode 100644 index 00000000000..3bea3a9f8fe --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt @@ -0,0 +1,116 @@ +path: "tensorflow.keras.mixed_precision.LossScaleOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "clipnorm" + mtype: "" + } + member { + name: "clipvalue" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "dynamic_counter" + mtype: "" + } + member { + name: "dynamic_growth_steps" + mtype: "" + } + member { + name: "global_clipnorm" + mtype: "" + } + member { + name: "initial_scale" + mtype: "" + } + member { + name: "inner_optimizer" + mtype: "" + } + member { + name: "iterations" + mtype: "" + } + member { + name: "loss_scale" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'inner_optimizer\', \'dynamic\', \'initial_scale\', \'dynamic_growth_steps\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], " + } + member_method { + name: "add_slot" + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'name\', \'experimental_aggregate_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_scaled_loss" + argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_unscaled_gradients" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\', \'tape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-policy.pbtxt new file mode 100644 index 00000000000..278d4fe120e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-policy.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.keras.mixed_precision.Policy" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "compute_dtype" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "variable_dtype" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt index 3c016d331de..a68c17e8f24 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt @@ -1,7 +1,8 @@ path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -13,10 +14,30 @@ tf_class { name: "clipvalue" mtype: "" } + member { + name: "dynamic" + mtype: "" + } + member { + name: "dynamic_counter" + mtype: "" + } + member { + name: "dynamic_growth_steps" + mtype: "" + } member { name: "global_clipnorm" mtype: "" } + member { + name: "initial_scale" + mtype: "" + } + member { + name: "inner_optimizer" + mtype: "" + } member { name: "iterations" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt index e3435a32bef..74b342efdb4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.mixed_precision.experimental.Policy" tf_class { - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "compute_dtype" @@ -14,10 +15,6 @@ tf_class { name: "name" mtype: "" } - member { - name: "should_cast_variables" - mtype: "" - } member { name: "variable_dtype" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt index e8648afb5f7..7ae84fcab43 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt @@ -1,7 +1,23 @@ path: "tensorflow.keras.mixed_precision" tf_module { + member { + name: "LossScaleOptimizer" + mtype: "" + } + member { + name: "Policy" + mtype: "" + } member { name: "experimental" mtype: "" } + member_method { + name: "global_policy" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_global_policy" + argspec: "args=[\'policy\'], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt index 15c0ab5abbb..0c837d030c3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -20,6 +24,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -128,6 +136,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -298,7 +310,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt index 729fdd660ca..4a595abe9db 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt @@ -14,6 +14,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "distribute_strategy" mtype: "" @@ -22,6 +26,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -130,6 +138,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" @@ -316,7 +328,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } member_method { name: "save_weights" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt index ac80126aaa3..798dd83194d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt @@ -30,6 +30,6 @@ tf_module { } member_method { name: "save_model" - argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt index 2ea4e8f84a6..77d0d2eeb70 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt @@ -144,6 +144,10 @@ tf_module { name: "erfc" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "erfcinv" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "erfinv" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt index e9177defa20..b2bc7a0a061 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -112,6 +120,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt index 6b63d833e42..562030170f3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt index 3d4d7f1ca97..c1a5b14483c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-crossentropy.pbtxt index bfa0226b05d..40bc3d2a97d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-crossentropy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt index a1216947733..24b5a3fd7f9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-crossentropy.pbtxt index 2f67f041b9a..4593381f91d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-crossentropy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-hinge.pbtxt index 1c64601ad46..44d35532aef 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-hinge.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-cosine-similarity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-cosine-similarity.pbtxt index b2954042bbf..53bde7f9272 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-cosine-similarity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-cosine-similarity.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt index 7bfee8ccd5d..4602dc18e9c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt index 44028896ecf..b29e3910f1a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-hinge.pbtxt index fee3302e540..70a5af53252 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-hinge.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-k-l-divergence.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-k-l-divergence.pbtxt index e548b5628b5..e9e5c421c8d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-k-l-divergence.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-k-l-divergence.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-log-cosh-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-log-cosh-error.pbtxt index 177850aaabf..6fbd6d8861a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-log-cosh-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-log-cosh-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-error.pbtxt index 429d31b967d..602cc6d0005 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-percentage-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-percentage-error.pbtxt index be99116d61b..9e3aaff1cd3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-percentage-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-absolute-percentage-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-io-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-io-u.pbtxt index 861283a1272..886e92c45a6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-io-u.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-io-u.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-relative-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-relative-error.pbtxt index 17604647fad..f5a4cce8130 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-relative-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-relative-error.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -110,6 +118,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-error.pbtxt index e2ce76c2a1b..7429e27f359 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-logarithmic-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-logarithmic-error.pbtxt index 4cc3f5c26e7..c03ee74e278 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-logarithmic-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-squared-logarithmic-error.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-tensor.pbtxt index 97ce4b030af..9f8ed91e211 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean-tensor.pbtxt @@ -12,6 +12,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "count" mtype: "" @@ -20,6 +24,10 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -116,6 +124,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt index de1ecd925e0..f79bdf8fd5a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-metric.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-metric.pbtxt index 28659b68e57..f0f17cca126 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-metric.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-metric.pbtxt @@ -11,10 +11,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -107,6 +115,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-poisson.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-poisson.pbtxt index 37440a594da..9a2cd2ab326 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-poisson.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-poisson.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision-at-recall.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision-at-recall.pbtxt index 8d1e8854aca..67a6d7f2d83 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision-at-recall.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision-at-recall.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt index b53e97f9cbb..02a7897e1fb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall-at-precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall-at-precision.pbtxt index 95a9c0b64c1..d631356d886 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall-at-precision.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall-at-precision.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt index 9f4d76fecba..3faf3a5c803 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt @@ -12,10 +12,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -108,6 +116,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-root-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-root-mean-squared-error.pbtxt index 9d5461d1fb7..1bf12651e20 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-root-mean-squared-error.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-root-mean-squared-error.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -110,6 +118,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sensitivity-at-specificity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sensitivity-at-specificity.pbtxt index 6799b934475..d96e46918f3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sensitivity-at-specificity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sensitivity-at-specificity.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt index 8e930197163..bcc0b52d2f1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-crossentropy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-crossentropy.pbtxt index f5a41280932..8bd7cf79013 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-crossentropy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-crossentropy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-top-k-categorical-accuracy.pbtxt index 463730dece5..ceaa33ef2f9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-top-k-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-specificity-at-sensitivity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-specificity-at-sensitivity.pbtxt index 4d1c97a60e0..09141d1b0cf 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-specificity-at-sensitivity.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-specificity-at-sensitivity.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-squared-hinge.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-squared-hinge.pbtxt index 03092595d20..db373528065 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-squared-hinge.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-squared-hinge.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sum.pbtxt index d738fa5a12a..477e71e2119 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sum.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sum.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-top-k-categorical-accuracy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-top-k-categorical-accuracy.pbtxt index 1a91cba2f63..f9259504e08 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-top-k-categorical-accuracy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-top-k-categorical-accuracy.pbtxt @@ -15,10 +15,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -111,6 +119,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt index 2c0c7cbab48..f9609c49b40 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt index beb1e64b5ea..6b49bb35b01 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt @@ -13,10 +13,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -109,6 +117,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt index e0a352e79bf..1aa698465a1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-device-wrapper.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt index d9a3159309d..7c3d535b215 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-dropout-wrapper.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt index cc3e1399eed..790b58ff9c4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-r-n-n-cell-residual-wrapper.pbtxt @@ -14,10 +14,18 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "compute_dtype" + mtype: "" + } member { name: "dtype" mtype: "" } + member { + name: "dtype_policy" + mtype: "" + } member { name: "dynamic" mtype: "" @@ -118,6 +126,10 @@ tf_class { name: "updates" mtype: "" } + member { + name: "variable_dtype" + mtype: "" + } member { name: "variables" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.experimental.-profiler-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.experimental.-profiler-options.pbtxt index 0ad95937220..1139c1c5038 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.profiler.experimental.-profiler-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.experimental.-profiler-options.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "delay_ms" + mtype: "" + } member { name: "device_tracer_level" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 2efd289c259..96be23b9e50 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -3212,6 +3212,10 @@ tf_module { name: "RaggedTensorToVariant" argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "RaggedTensorToVariantGradient" + argspec: "args=[\'encoded_ragged_grad\', \'row_splits\', \'dense_values_shape\', \'Tvalues\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "RandomCrop" argspec: "args=[\'image\', \'size\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " diff --git a/tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh b/tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh index 523daf666d5..047b169d13a 100644 --- a/tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh +++ b/tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh @@ -19,11 +19,11 @@ set -e set -x # CPU size -MAC_CPU_MAX_WHL_SIZE=165M +MAC_CPU_MAX_WHL_SIZE=175M LINUX_CPU_MAX_WHL_SIZE=138M WIN_CPU_MAX_WHL_SIZE=113M # GPU size -LINUX_GPU_MAX_WHL_SIZE=380M +LINUX_GPU_MAX_WHL_SIZE=390M WIN_GPU_MAX_WHL_SIZE=252M function run_smoke_test() { diff --git a/tensorflow/tools/ci_build/presubmit/macos/py2_cc/build.sh b/tensorflow/tools/ci_build/presubmit/macos/py2_cc/build.sh deleted file mode 100644 index 9bce4d1020c..00000000000 --- a/tensorflow/tools/ci_build/presubmit/macos/py2_cc/build.sh +++ /dev/null @@ -1,68 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# TODO(mihaimaruseac,hyey,ggadde): Convert to py3 - -set -e - -# Error if we somehow forget to set the path to bazel_wrapper.py -set -u -BAZEL_WRAPPER_PATH=$1 -set +u - -# From this point on, logs can be publicly available -set -x - -function setup_pip () { - install_pip2 - python -m virtualenv tf_build_env --system-site-packages - source tf_build_env/bin/activate - install_macos_pip_deps -} - -function run_build () { - # Run configure. - export TF_NEED_CUDA=0 - export PYTHON_BIN_PATH=$(which python2) - yes "" | $PYTHON_BIN_PATH configure.py - tag_filters="-no_oss,-no_oss_py2,-gpu,-tpu,-benchmark-test,-nomac,-no_mac,-v1only" - - # Get the default test targets for bazel. - source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - - "${BAZEL_WRAPPER_PATH}" \ - test \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" \ - --action_env=PATH \ - --remote_accept_cached=true \ - --spawn_strategy=standalone \ - --remote_local_fallback=false \ - --remote_timeout=600 \ - --strategy=Javac=standalone \ - --strategy=Closure=standalone \ - --genrule_strategy=standalone \ - -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... - - # Copy log to output to be available to GitHub - ls -la "$(bazel info output_base)/java.log" - cp "$(bazel info output_base)/java.log" "${KOKORO_ARTIFACTS_DIR}/" -} - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -setup_pip -run_build diff --git a/tensorflow/tools/ci_build/presubmit/macos/py37_cc/build.sh b/tensorflow/tools/ci_build/presubmit/macos/py37_cc/build.sh index e648c488a00..54bdd261fdf 100644 --- a/tensorflow/tools/ci_build/presubmit/macos/py37_cc/build.sh +++ b/tensorflow/tools/ci_build/presubmit/macos/py37_cc/build.sh @@ -52,7 +52,7 @@ function run_build () { --strategy=Javac=standalone \ --strategy=Closure=standalone \ --genrule_strategy=standalone \ - -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... + -- ${DEFAULT_BAZEL_TARGETS} # Copy log to output to be available to GitHub ls -la "$(bazel info output_base)/java.log" diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_libtensorflow.sh b/tensorflow/tools/ci_build/rel/macos/cpu_libtensorflow.sh index 148ab16de6c..937ab048859 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_libtensorflow.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_libtensorflow.sh @@ -23,7 +23,7 @@ if [[ "$IS_NIGHTLY" -eq 1 ]]; then install_bazelisk # Pick a version of xcode - export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer + export DEVELOPER_DIR=/Applications/Xcode_11.3.app/Contents/Developer sudo xcode-select -s "${DEVELOPER_DIR}" # Update the version string to nightly diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py35_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py35_nonpip.sh deleted file mode 100644 index dc1491d65d7..00000000000 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py35_nonpip.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" -python3.5 -m virtualenv tf_build_env --system-site-packages -source tf_build_env/bin/activate - -# Install macos pip dependencies -install_macos_pip_deps sudo pip3.5 - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export TF2_BEHAVIOR=1 -export PYTHON_BIN_PATH=$(which python3.5) -yes "" | "$PYTHON_BIN_PATH" configure.py - -tag_filters="-no_oss,-oss_serial,-nomac,-no_mac,-no_oss_py35,-v1only,-gpu,-tpu,-benchmark-test" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -bazel test --test_output=errors --config=opt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} \ - -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py35_pip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py35_pip.sh deleted file mode 100644 index 99c2a149394..00000000000 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py35_pip.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" - -# Install macos pip dependencies -install_macos_pip_deps sudo pip3.5 - -# Export required variables for running pip_new.sh -export OS_TYPE="MACOS" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.5' -export TF_BUILD_BOTH_CPU_PACKAGES=1 - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=release_cpu_macos" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" -export TF_TEST_TARGETS="//tensorflow/python/..." -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py35,-gpu,-tpu,-benchmark-test' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh index eb245cb1d04..47dab895d2c 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh @@ -20,7 +20,7 @@ source tensorflow/tools/ci_build/release/common.sh install_bazelisk # Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer +export DEVELOPER_DIR=/Applications/Xcode_11.3.app/Contents/Developer sudo xcode-select -s "${DEVELOPER_DIR}" python3.6 -m virtualenv tf_build_env --system-site-packages source tf_build_env/bin/activate diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py36_pip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py36_pip.sh index 375a8c705fa..0e944e3469e 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py36_pip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py36_pip.sh @@ -20,7 +20,7 @@ source tensorflow/tools/ci_build/release/common.sh install_bazelisk # Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer +export DEVELOPER_DIR=/Applications/Xcode_11.3.app/Contents/Developer sudo xcode-select -s "${DEVELOPER_DIR}" # Install macos pip dependencies diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh index b9e6c3c9cf0..3b26b0c8156 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh @@ -20,7 +20,7 @@ source tensorflow/tools/ci_build/release/common.sh install_bazelisk # Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer +export DEVELOPER_DIR=/Applications/Xcode_11.3.app/Contents/Developer sudo xcode-select -s "${DEVELOPER_DIR}" python -m virtualenv tf_build_env --system-site-packages source tf_build_env/bin/activate diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py37_pip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py37_pip.sh index ea6779be698..00dcf7b3f46 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py37_pip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py37_pip.sh @@ -20,7 +20,7 @@ source tensorflow/tools/ci_build/release/common.sh install_bazelisk # Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer +export DEVELOPER_DIR=/Applications/Xcode_11.3.app/Contents/Developer sudo xcode-select -s "${DEVELOPER_DIR}" # Install macos pip dependencies diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh index a90d59ff492..49714f82039 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh @@ -20,7 +20,7 @@ source tensorflow/tools/ci_build/release/common.sh install_bazelisk # Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer +export DEVELOPER_DIR=/Applications/Xcode_11.3.app/Contents/Developer sudo xcode-select -s "${DEVELOPER_DIR}" python -m virtualenv tf_build_env --system-site-packages source tf_build_env/bin/activate diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py38_pip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py38_pip.sh index f0ef8e89766..b901116f4fc 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py38_pip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py38_pip.sh @@ -20,7 +20,7 @@ source tensorflow/tools/ci_build/release/common.sh install_bazelisk # Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer +export DEVELOPER_DIR=/Applications/Xcode_11.3.app/Contents/Developer sudo xcode-select -s "${DEVELOPER_DIR}" # Install macos pip dependencies diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_py35_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_py35_nonpip.sh deleted file mode 100644 index fee64f0beb1..00000000000 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_py35_nonpip.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 -# Update bazel -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -export TF2_BEHAVIOR=1 -yes "" | "$PYTHON_BIN_PATH" configure.py -tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py35,-v1only" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -set +e -bazel test --test_output=errors --config=opt --test_lang_filters=py \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_py35_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_py35_pip.sh deleted file mode 100644 index bdbb7f15e34..00000000000 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_py35_pip.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 -# Update bazel -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.5' - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=release_cpu_linux" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" -export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py35,-v1only' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow_cpu" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/ubuntu/gpu_py35_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu/gpu_py35_nonpip.sh deleted file mode 100644 index a8dfd2047ba..00000000000 --- a/tensorflow/tools/ci_build/rel/ubuntu/gpu_py35_nonpip.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 -# Update bazel -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=11 -export TF_CUDNN_VERSION=8 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -export TF2_BEHAVIOR=1 -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py35" - -set +e -bazel test --config=cuda --config=opt \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ - --linkopt=-lrt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --test_lang_filters=py \ - --test_tag_filters=${tag_filters} \ - --build_tag_filters=${tag_filters} \ - --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ - --test_output=errors --verbose_failures=true --keep_going \ - --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ - -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/ubuntu/gpu_py35_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu/gpu_py35_pip.sh deleted file mode 100644 index f178ac0754e..00000000000 --- a/tensorflow/tools/ci_build/rel/ubuntu/gpu_py35_pip.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 -# Update bazel -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="GPU" -export TF_PYTHON_VERSION='python3.5' - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py35' -export TF_BUILD_FLAGS="--config=release_gpu_linux " -export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ ---distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION=11 --action_env=TF_CUDNN_VERSION=8 --test_env=TF2_BEHAVIOR=1 \ ---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ ---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ ---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " -export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow_gpu" -export TF_PIP_TEST_ROOT="pip_test" - -# To build both tensorflow and tensorflow-gpu pip packages -export TF_BUILD_BOTH_GPU_PACKAGES=1 - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/rel/windows/cpu_py35.bat b/tensorflow/tools/ci_build/rel/windows/cpu_py35.bat deleted file mode 100644 index 4122756ef40..00000000000 --- a/tensorflow/tools/ci_build/rel/windows/cpu_py35.bat +++ /dev/null @@ -1,24 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python35 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -if "%IS_NIGHTLY%" == "1" ( - call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" -) else ( - call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" -) diff --git a/tensorflow/tools/ci_build/rel/windows/gpu_py35.bat b/tensorflow/tools/ci_build/rel/windows/gpu_py35.bat deleted file mode 100644 index 7a8eb53d1e1..00000000000 --- a/tensorflow/tools/ci_build/rel/windows/gpu_py35.bat +++ /dev/null @@ -1,26 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python35 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -if "%IS_NIGHTLY%" == "1" ( - call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" -) else ( - call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" - for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" - bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh -) diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh index ec335ad7408..7e837596350 100644 --- a/tensorflow/tools/ci_build/release/common.sh +++ b/tensorflow/tools/ci_build/release/common.sh @@ -103,55 +103,6 @@ function update_bazel_linux { # LINT.ThenChange( # //tensorflow_estimator/google/kokoro/common.sh) -function install_pip2 { - curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py - sudo python2 get-pip.py -} - -function install_pip3.5 { - curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py - sudo python3.5 get-pip.py -} - -function install_pip_deps { - SUDO_CMD="" - PIP_CMD="pip" - - while true; do - if [[ -z "${1}" ]]; then - break - fi - if [[ "$1" == "sudo" ]]; then - SUDO_CMD="sudo " - elif [[ "$1" == "pip"* ]]; then - PIP_CMD="$1" - fi - shift - done - - # LINT.IfChange(ubuntu_pip_installations) - # TODO(aselle): Change all these to be --user instead of sudo. - ${SUDO_CMD} ${PIP_CMD} install astunparse==1.6.3 - ${SUDO_CMD} ${PIP_CMD} install keras_preprocessing==1.1.0 --no-deps - "${PIP_CMD}" install numpy==1.16.0 --user - "${PIP_CMD}" install PyYAML==3.13 --user - ${SUDO_CMD} ${PIP_CMD} install gast==0.3.3 - ${SUDO_CMD} ${PIP_CMD} install h5py==2.10.0 - ${SUDO_CMD} ${PIP_CMD} install six==1.12.0 - ${SUDO_CMD} ${PIP_CMD} install grpcio - ${SUDO_CMD} ${PIP_CMD} install portpicker - ${SUDO_CMD} ${PIP_CMD} install scipy - ${SUDO_CMD} ${PIP_CMD} install scikit-learn - ${SUDO_CMD} ${PIP_CMD} install typing_extensions - ${SUDO_CMD} ${PIP_CMD} install --upgrade tb-nightly - ${PIP_CMD} install --user --upgrade flatbuffers - ${PIP_CMD} install --user --upgrade attrs - ${PIP_CMD} install --user --upgrade tf-estimator-nightly - ${PIP_CMD} install --user --upgrade "future>=0.17.1" - ${PIP_CMD} install --user --upgrade wrapt - # LINT.ThenChange(:ubuntu_16_pip_installations) -} - function install_ubuntu_16_pip_deps { PIP_CMD="pip" @@ -165,30 +116,44 @@ function install_ubuntu_16_pip_deps { shift done - # LINT.IfChange(ubuntu_16_pip_installations) - "${PIP_CMD}" install astunparse==1.6.3 --user - "${PIP_CMD}" install --user --upgrade attrs - "${PIP_CMD}" install --user --upgrade flatbuffers - "${PIP_CMD}" install keras_preprocessing==1.1.0 --no-deps --user - "${PIP_CMD}" install numpy==1.16.0 --user - "${PIP_CMD}" install --user --upgrade "future>=0.17.1" - "${PIP_CMD}" install gast==0.3.3 --user - "${PIP_CMD}" install h5py==2.10.0 --user - "${PIP_CMD}" install six==1.12.0 --user - "${PIP_CMD}" install grpcio --user - "${PIP_CMD}" install portpicker --user - "${PIP_CMD}" install scipy --user - "${PIP_CMD}" install scikit-learn --user - "${PIP_CMD}" install typing_extensions --user - "${PIP_CMD}" install PyYAML==3.13 --user - # b/156523241 - "${PIP_CMD}" install --force-reinstall --user --upgrade tf-estimator-nightly - "${PIP_CMD}" install --user --upgrade tb-nightly - "${PIP_CMD}" install --user --upgrade wrapt - # LINT.ThenChange(:ubuntu_pip_installations) + # LINT.IfChange(linux_pip_installations) + # To have reproducible builds, these dependencies should be pinned always. + # Prefer pinning to the same version as in setup.py + # First, upgrade pypi wheels + "${PIP_CMD}" install --user --upgrade setuptools pip wheel + # Now, install the deps, as listed in setup.py + "${PIP_CMD}" install --user 'absl-py ~= 0.10' + "${PIP_CMD}" install --user 'astunparse ~= 1.6.3' + "${PIP_CMD}" install --user 'flatbuffers ~= 1.12.0' + "${PIP_CMD}" install --user 'google_pasta ~= 0.2' + "${PIP_CMD}" install --user 'h5py ~= 2.10.0' + "${PIP_CMD}" install --user 'keras_preprocessing ~= 1.1.2' + "${PIP_CMD}" install --user 'numpy ~= 1.19.2' + "${PIP_CMD}" install --user 'opt_einsum ~= 3.3.0' + "${PIP_CMD}" install --user 'protobuf ~= 3.13.0' + "${PIP_CMD}" install --user 'six ~= 1.15.0' + "${PIP_CMD}" install --user 'termcolor ~= 1.1.0' + "${PIP_CMD}" install --user 'typing_extensions ~= 3.7.4' + "${PIP_CMD}" install --user 'wheel ~= 0.35' + "${PIP_CMD}" install --user 'wrapt ~= 1.12.1' + # We need to pin gast dependency exactly + "${PIP_CMD}" install --user 'gast == 0.3.3' + # Finally, install tensorboard and estimator + # Note that here we want the latest version that matches (b/156523241) + "${PIP_CMD}" install --user --upgrade --force-reinstall 'tb-nightly ~= 2.4.0.a' + "${PIP_CMD}" install --user --upgrade --force-reinstall 'tensorflow_estimator ~= 2.3.0' + # Test dependencies + "${PIP_CMD}" install --user 'grpcio ~= 1.32.0' + "${PIP_CMD}" install --user 'portpicker ~= 1.3.1' + "${PIP_CMD}" install --user 'scipy ~= 1.5.2' + # LINT.ThenChange(:mac_pip_installations) + # Need to be addressed later. Unblocking 2.4 branchcut + "${PIP_CMD}" install --user 'PyYAML ~= 5.3.1' } function install_macos_pip_deps { + # TODO(mihaimaruseac): Remove need for sudo, then this can be merged with + # above (probably needs to convert to venv too). SUDO_CMD="" PIP_CMD="pip" @@ -207,30 +172,37 @@ function install_macos_pip_deps { shift done - # High Sierra pip for Python2.7 installs don't work as expected. - if [[ "${PIP_CMD}" == "pip" ]]; then - PIP_CMD="python -m pip" - SUDO_CMD="sudo -H " - fi - - # TODO(aselle): Change all these to be --user instead of sudo. - ${SUDO_CMD} ${PIP_CMD} install --upgrade setuptools==39.1.0 - ${SUDO_CMD} ${PIP_CMD} install keras_preprocessing==1.1.0 --no-deps - ${SUDO_CMD} ${PIP_CMD} install --upgrade mock portpicker scipy grpcio - ${SUDO_CMD} ${PIP_CMD} install six==1.12.0 - ${SUDO_CMD} ${PIP_CMD} install scikit-learn - ${SUDO_CMD} ${PIP_CMD} install numpy==1.16.0 - ${SUDO_CMD} ${PIP_CMD} install gast==0.3.3 - ${SUDO_CMD} ${PIP_CMD} install h5py==2.10.0 - ${SUDO_CMD} ${PIP_CMD} install typing_extensions - ${SUDO_CMD} ${PIP_CMD} install --upgrade grpcio - ${SUDO_CMD} ${PIP_CMD} install --upgrade tb-nightly - ${PIP_CMD} install --user --upgrade flatbuffers - ${PIP_CMD} install --user --upgrade attrs - # b/156523241 - ${PIP_CMD} install --force-reinstall --user --upgrade tf-estimator-nightly - ${PIP_CMD} install --user --upgrade wrapt - ${PIP_CMD} install --user --upgrade "future>=0.17.1" + # LINT.IfChange(mac_pip_installations) + # To have reproducible builds, these dependencies should be pinned always. + # Prefer pinning to the same version as in setup.py + # First, upgrade pypi wheels + ${PIP_CMD} install --user --upgrade setuptools pip wheel + # Now, install the deps, as listed in setup.py + ${PIP_CMD} install --user 'absl-py ~= 0.10' + ${PIP_CMD} install --user 'astunparse ~= 1.6.3' + ${PIP_CMD} install --user 'flatbuffers ~= 1.12.0' + ${PIP_CMD} install --user 'google_pasta ~= 0.2' + ${PIP_CMD} install --user 'h5py ~= 2.10.0' + ${PIP_CMD} install --user 'keras_preprocessing ~= 1.1.2' + ${PIP_CMD} install --user 'numpy ~= 1.19.2' + ${PIP_CMD} install --user 'opt_einsum ~= 3.3.0' + ${PIP_CMD} install --user 'protobuf ~= 3.13.0' + ${PIP_CMD} install --user 'six ~= 1.15.0' + ${PIP_CMD} install --user 'termcolor ~= 1.1.0' + ${PIP_CMD} install --user 'typing_extensions ~= 3.7.4' + ${PIP_CMD} install --user 'wheel ~= 0.35' + ${PIP_CMD} install --user 'wrapt ~= 1.12.1' + # We need to pin gast dependency exactly + ${PIP_CMD} install --user 'gast == 0.3.3' + # Finally, install tensorboard and estimator + # Note that here we want the latest version that matches (b/156523241) + ${PIP_CMD} install --user --upgrade --force-reinstall 'tb-nightly ~= 2.4.0.a' + ${PIP_CMD} install --user --upgrade --force-reinstall 'tensorflow_estimator ~= 2.3.0' + # Test dependencies + ${PIP_CMD} install --user 'grpcio ~= 1.32.0' + ${PIP_CMD} install --user 'portpicker ~= 1.3.1' + ${PIP_CMD} install --user 'scipy ~= 1.5.2' + # LINT.ThenChange(:linux_pip_installations) } function maybe_skip_v1 { diff --git a/tensorflow/tools/ci_build/release/common_win.bat b/tensorflow/tools/ci_build/release/common_win.bat index 41c536a58ff..a49140c7574 100644 --- a/tensorflow/tools/ci_build/release/common_win.bat +++ b/tensorflow/tools/ci_build/release/common_win.bat @@ -21,43 +21,37 @@ IF NOT DEFINED PYTHON_DIRECTORY ( SET PYTHON_DIRECTORY=Python36 ) SET PY_EXE=C:\%PYTHON_DIRECTORY%\python.exe -SET PIP_EXE=C:\%PYTHON_DIRECTORY%\Scripts\pip.exe SET PATH=%PATH%;C:\%PYTHON_DIRECTORY% -@REM TODO(amitpatankar): Make an image with these packages and remove this. - -%PIP_EXE% install flatbuffers --upgrade --no-deps -%PIP_EXE% install setuptools --upgrade -%PIP_EXE% install future>=0.17.1 --no-deps -%PIP_EXE% install --ignore-installed --force-reinstall --upgrade tf-estimator-nightly --no-deps -%PIP_EXE% install tb-nightly --no-deps -%PIP_EXE% install numpy==1.16.0 --upgrade --no-deps -%PIP_EXE% install opt_einsum --upgrade -%PIP_EXE% install pandas --upgrade --no-deps -%PIP_EXE% install protobuf --upgrade --no-deps -%PIP_EXE% install keras_preprocessing==1.1.0 --upgrade --no-deps -%PIP_EXE% install wrapt --upgrade --no-deps -%PIP_EXE% install absl-py==0.9.0 - -IF "%PYTHON_DIRECTORY%"=="Python37" ( - %PIP_EXE% install colorama==0.3.9 - %PIP_EXE% install cycler==0.10.0 - %PIP_EXE% install jedi==0.11.1 - %PIP_EXE% install oauth2client==4.1.2 - %PIP_EXE% install portpicker==1.2.0 - %PIP_EXE% install parso==0.1.1 - %PIP_EXE% install protobuf==3.8.0 - %PIP_EXE% install scikit-learn==0.19.2 - %PIP_EXE% install scipy==1.1.0 - %PIP_EXE% install termcolor==1.1.0 -) - -@REM TODO(amitpatankar): this is just a quick fix so that windows build doesn't -@REM break with gast upgrade to 0.3.3. Need to figure out the right way to -@REM handle this case. -%PIP_EXE% install gast==0.3.3 -%PIP_EXE% install astunparse==1.6.3 -%PIP_EXE% install typing_extensions +@REM To have reproducible builds, these dependencies should be pinned always. +@REM Prefer pinning to the same version as in setup.py +@REM First, upgrade pypi wheels +%PY_EXE% -m pip install --upgrade setuptools pip wheel +@REM Now, install the deps, as listed in setup.py +%PY_EXE% -m pip install "absl-py ~= 0.10" +%PY_EXE% -m pip install "astunparse ~= 1.6.3" +%PY_EXE% -m pip install "flatbuffers ~= 1.12.0" +%PY_EXE% -m pip install "google_pasta ~= 0.2" +%PY_EXE% -m pip install "h5py ~= 2.10.0" +%PY_EXE% -m pip install "keras_preprocessing ~= 1.1.2" +%PY_EXE% -m pip install "numpy ~= 1.19.2" +%PY_EXE% -m pip install "opt_einsum ~= 3.3.0" +%PY_EXE% -m pip install "protobuf ~= 3.13.0" +%PY_EXE% -m pip install "six ~= 1.15.0" +%PY_EXE% -m pip install "termcolor ~= 1.1.0" +%PY_EXE% -m pip install "typing_extensions ~= 3.7.4" +%PY_EXE% -m pip install "wheel ~= 0.35" +%PY_EXE% -m pip install "wrapt ~= 1.12.1" +@REM We need to pin gast dependency exactly +%PY_EXE% -m pip install "gast == 0.3.3" +@REM Finally, install tensorboard and estimator +@REM Note that here we want the latest version that matches (b/156523241) +%PY_EXE% -m pip install --upgrade --force-reinstall "tb-nightly ~= 2.4.0.a" +%PY_EXE% -m pip install --upgrade --force-reinstall "tensorflow_estimator ~= 2.3.0" +@REM Test dependencies +%PY_EXE% -m pip install "grpcio ~= 1.32.0" +%PY_EXE% -m pip install "portpicker ~= 1.3.1" +%PY_EXE% -m pip install "scipy ~= 1.5.2" :: Set cuda related environment variables. If we are not using CUDA, these are not used. IF NOT DEFINED TF_CUDA_VERSION ( diff --git a/tensorflow/tools/ci_build/release/macos/cpu_libtensorflow/build.sh b/tensorflow/tools/ci_build/release/macos/cpu_libtensorflow/build.sh deleted file mode 100644 index 3dfab5a2aaa..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_libtensorflow/build.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -echo "chmod go+w lib_package/*" >> tensorflow/tools/ci_build/linux/libtensorflow.sh -echo "bazel clean --expunge" >> tensorflow/tools/ci_build/linux/libtensorflow.sh - -# Install latest bazel -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" - -# Update the version string to nightly -./tensorflow/tools/ci_build/update_version.py --nightly - -tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh - -# Copy the nightly version update script -cp tensorflow/tools/ci_build/builds/libtensorflow_nightly_symlink.sh lib_package diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/nightly_release.sh b/tensorflow/tools/ci_build/release/macos/cpu_py35_full/nightly_release.sh deleted file mode 100644 index 7da3b0ea9be..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/nightly_release.sh +++ /dev/null @@ -1,67 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" - -sudo pip3.5 install --upgrade pip - -install_macos_pip_deps sudo pip3.5 - -# For python3 path on Mac -export PATH=$PATH:/usr/local/bin - -sudo pip install twine - -./tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_cpu_macos tensorflow/tools/pip_package:build_pip_package -mkdir pip_pkg -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag - -# Copy and rename to tf_nightly -for f in $(ls pip_pkg/tf_nightly_cpu-*dev*macosx*.whl); do - copy_to_new_project_name "${f}" tf_nightly -done - -# Upload the built packages to pypi. -for f in $(ls pip_pkg/tf_nightly*dev*macosx*.whl); do - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${f} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${f}" - twine upload -r pypi-warehouse "${f}" || echo - else - echo "Basic PIP test FAILED, will not upload ${f} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/nonpip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py35_full/nonpip.sh deleted file mode 100644 index dc1491d65d7..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/nonpip.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" -python3.5 -m virtualenv tf_build_env --system-site-packages -source tf_build_env/bin/activate - -# Install macos pip dependencies -install_macos_pip_deps sudo pip3.5 - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export TF2_BEHAVIOR=1 -export PYTHON_BIN_PATH=$(which python3.5) -yes "" | "$PYTHON_BIN_PATH" configure.py - -tag_filters="-no_oss,-oss_serial,-nomac,-no_mac,-no_oss_py35,-v1only,-gpu,-tpu,-benchmark-test" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -bazel test --test_output=errors --config=opt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} \ - -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/nonpip_v1.sh b/tensorflow/tools/ci_build/release/macos/cpu_py35_full/nonpip_v1.sh deleted file mode 100644 index f045e7103e0..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/nonpip_v1.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -sudo xcode-select --switch /Applications/Xcode_9.2.app/Contents/Developer - -# Install pip dependencies -install_pip3.5 -install_macos_pip_deps sudo pip3.5 - -export PATH=$PATH:/usr/local/bin - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -yes "" | "$PYTHON_BIN_PATH" configure.py - -tag_filters="-no_oss,-oss_serial,-nomac,-no_mac,-no_oss_py35" - -# Run tests -bazel test --test_output=errors --config=opt \ - --incompatible_depset_union=false \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - //tensorflow/... \ - -//tensorflow/compiler/... \ - -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip.sh deleted file mode 100644 index 99c2a149394..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" - -# Install macos pip dependencies -install_macos_pip_deps sudo pip3.5 - -# Export required variables for running pip_new.sh -export OS_TYPE="MACOS" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.5' -export TF_BUILD_BOTH_CPU_PACKAGES=1 - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=release_cpu_macos" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" -export TF_TEST_TARGETS="//tensorflow/python/..." -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py35,-gpu,-tpu,-benchmark-test' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip_v1.sh b/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip_v1.sh deleted file mode 100644 index dcbd5b504c8..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip_v1.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Install pip dependencies -sudo pip3.5 install --upgrade pip -install_macos_pip_deps sudo pip3.5 - -# For python3 path on Mac -export PATH=$PATH:/usr/local/bin - -# Export required variables for running pip.sh -export OS_TYPE="MACOS" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.5' - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=opt" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going" -export TF_TEST_TARGETS="//tensorflow/python/..." -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py35' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh - -# Copy and rename to tensorflow_cpu -for WHL_PATH in $(ls ${TF_ARTIFACTS_DIR}/tensorflow/${TF_PIP_TEST_ROOT}/whl/tensorflow*.whl); do - copy_to_new_project_name "${WHL_PATH}" tensorflow_cpu -done diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/release.sh b/tensorflow/tools/ci_build/release/macos/cpu_py35_full/release.sh deleted file mode 100644 index 8ee43fb1b2f..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/release.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# For python3 path on Mac -export PATH=$PATH:/usr/local/bin - -sudo pip install twine - -./tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=opt tensorflow/tools/pip_package:build_pip_package -mkdir pip_pkg -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --nightly_flag - -# Also upload the python 3.5 package as python 3.3 and 3.4 packages. -FILENAME="$(ls pip_pkg/tf_nightly-*dev*-macosx_*.whl)" -tensorflow/tools/ci_build/copy_binary.py --filename "${FILENAME}" --new_py_ver 33 -tensorflow/tools/ci_build/copy_binary.py --filename "${FILENAME}" --new_py_ver 34 - -for f in $(ls pip_pkg/tf_nightly-*dev*macosx*.whl); do - echo "Uploading package: ${f}" - twine upload -r pypi-warehouse "${f}" || echo -done diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/nightly_release.sh b/tensorflow/tools/ci_build/release/macos/cpu_py36_full/nightly_release.sh deleted file mode 100644 index 33e1491dd86..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/nightly_release.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" - -install_macos_pip_deps sudo pip3.6 - -# For python3 path on Mac -export PATH=$PATH:/usr/local/bin - -sudo pip install twine - -./tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.6) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_cpu_macos tensorflow/tools/pip_package:build_pip_package -mkdir pip_pkg -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag - -# Copy and rename to tf_nightly -for f in $(ls pip_pkg/tf_nightly_cpu-*dev*macosx*.whl); do - copy_to_new_project_name "${f}" tf_nightly -done - -# Upload the built packages to pypi. -for f in $(ls pip_pkg/tf_nightly*dev*macosx*.whl); do - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${f} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${f}" - twine upload -r pypi-warehouse "${f}" || echo - else - echo "Basic PIP test FAILED, will not upload ${f} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/nonpip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py36_full/nonpip.sh deleted file mode 100644 index eb245cb1d04..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/nonpip.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" -python3.6 -m virtualenv tf_build_env --system-site-packages -source tf_build_env/bin/activate - -# Install macos pip dependencies -install_macos_pip_deps sudo pip3.6 - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export TF2_BEHAVIOR=1 -export PYTHON_BIN_PATH=$(which python3.6) -yes "" | "$PYTHON_BIN_PATH" configure.py - -tag_filters="-no_oss,-oss_serial,-nomac,-no_mac,-no_oss_py36,-v1only,-gpu,-tpu,-benchmark-test" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -bazel test --test_output=errors --config=opt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} \ - -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/nonpip_v1.sh b/tensorflow/tools/ci_build/release/macos/cpu_py36_full/nonpip_v1.sh deleted file mode 100644 index 2f639d7fc6b..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/nonpip_v1.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -sudo xcode-select --switch /Applications/Xcode_9.2.app/Contents/Developer - -# Install pip dependencies -install_macos_pip_deps sudo pip3.6 - -export PATH=$PATH:/usr/local/bin - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.6) -yes "" | "$PYTHON_BIN_PATH" configure.py - -tag_filters="-no_oss,-oss_serial,-nomac,-no_mac" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -bazel test --test_output=errors --config=opt \ - --incompatible_depset_union=false \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} \ - -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip.sh deleted file mode 100644 index 375a8c705fa..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" - -# Install macos pip dependencies -install_macos_pip_deps sudo pip3.6 - -# Export required variables for running pip_new.sh -export OS_TYPE="MACOS" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.6' -export TF_BUILD_BOTH_CPU_PACKAGES=1 - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=release_cpu_macos" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" -export TF_TEST_TARGETS="//tensorflow/python/..." -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py35,-v1only,-gpu,-tpu,-benchmark-test' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip_v1.sh b/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip_v1.sh deleted file mode 100644 index 3d04cf1d9ba..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip_v1.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Install pip dependencies -install_macos_pip_deps sudo pip3.6 - -# For python3 path on Mac -export PATH=$PATH:/usr/local/bin - -# Export required variables for running pip.sh -export OS_TYPE="MACOS" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.6' - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=opt" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going" -export TF_TEST_TARGETS="//tensorflow/python/..." -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh - -# Copy and rename to tensorflow_cpu -for WHL_PATH in $(ls "${TF_ARTIFACTS_DIR}"/tensorflow/${TF_PIP_TEST_ROOT}/whl/tensorflow*.whl); do - copy_to_new_project_name "${WHL_PATH}" tensorflow_cpu -done diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/nightly_release.sh b/tensorflow/tools/ci_build/release/macos/cpu_py37_full/nightly_release.sh deleted file mode 100644 index 631aea318bd..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/nightly_release.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" - -install_macos_pip_deps sudo pip3.7 - -# For python3 path on Mac -export PATH=$PATH:/usr/local/bin - -sudo pip install twine - -./tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.7) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_cpu_macos tensorflow/tools/pip_package:build_pip_package -mkdir pip_pkg -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag - -# Copy and rename to tf_nightly -for f in $(ls pip_pkg/tf_nightly_cpu-*dev*macosx*.whl); do - copy_to_new_project_name "${f}" tf_nightly -done - -# Upload the built packages to pypi. -for f in $(ls pip_pkg/tf_nightly*dev*macosx*.whl); do - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${f} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${f}" - twine upload -r pypi-warehouse "${f}" || echo - else - echo "Basic PIP test FAILED, will not upload ${f} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/nonpip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py37_full/nonpip.sh deleted file mode 100644 index b9e6c3c9cf0..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/nonpip.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" -python -m virtualenv tf_build_env --system-site-packages -source tf_build_env/bin/activate - -# Install macos pip dependencies -install_macos_pip_deps sudo pip3.7 - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export TF2_BEHAVIOR=1 -export PYTHON_BIN_PATH=$(which python3.7) -yes "" | "$PYTHON_BIN_PATH" configure.py - -tag_filters="-no_oss,-oss_serial,-nomac,-no_mac$(maybe_skip_v1),-gpu,-tpu,-benchmark-test" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -bazel test --test_output=errors --config=opt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} \ - -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/nonpip_v1.sh b/tensorflow/tools/ci_build/release/macos/cpu_py37_full/nonpip_v1.sh deleted file mode 100644 index a05cd81d74f..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/nonpip_v1.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -sudo xcode-select --switch /Applications/Xcode_9.2.app/Contents/Developer - -# Install pip dependencies -install_macos_pip_deps sudo pip3.7 - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.7) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -tag_filters="-no_oss,-oss_serial,-nomac,-no_mac" - -# Run tests -bazel test --test_output=errors --config=opt \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} \ - -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip.sh deleted file mode 100644 index ea6779be698..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" - -# Install macos pip dependencies -install_macos_pip_deps sudo pip3.7 - -# Export required variables for running pip_new.sh -export OS_TYPE="MACOS" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.7' -export TF_BUILD_BOTH_CPU_PACKAGES=1 - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=release_cpu_macos" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" -export TF_TEST_TARGETS="//tensorflow/python/..." -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py37,-v1only,-gpu,-tpu,-benchmark-test' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip_v1.sh b/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip_v1.sh deleted file mode 100644 index c3840aa2dc8..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip_v1.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Install pip dependencies -install_macos_pip_deps sudo pip3.7 - -# For python3 path on Mac -export PATH=$PATH:/usr/local/bin - -# Export required variables for running pip.sh -export OS_TYPE="MACOS" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.7' - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=opt" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going" -export TF_TEST_TARGETS="//tensorflow/python/..." -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh - -# Copy and rename to tensorflow_cpu -for WHL_PATH in $(ls ${TF_PIP_TEST_ROOT}/whl/tensorflow*.whl); do - copy_to_new_project_name "${WHL_PATH}" tensorflow_cpu -done diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/release.sh b/tensorflow/tools/ci_build/release/macos/cpu_py37_full/release.sh deleted file mode 100644 index 7465838abb9..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/release.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -install_macos_pip_deps sudo pip3.7 - -# For python3 path on Mac -export PATH=$PATH:/usr/local/bin - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.7) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=opt tensorflow/tools/pip_package:build_pip_package -mkdir pip_pkg -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg - -# Copy and rename to tensorflow_cpu -for WHL_PATH in $(ls "${TF_ARTIFACTS_DIR}"/github/tensorflow/pip_pkg/tensorflow*.whl); do - copy_to_new_project_name "${WHL_PATH}" tensorflow_cpu -done diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py38_full/nightly_release.sh b/tensorflow/tools/ci_build/release/macos/cpu_py38_full/nightly_release.sh deleted file mode 100644 index 5ffef89188c..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py38_full/nightly_release.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" - -install_macos_pip_deps sudo pip3.8 - -# For python3 path on Mac -export PATH=$PATH:/usr/local/bin - -sudo pip install twine - -./tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.8) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_cpu_macos tensorflow/tools/pip_package:build_pip_package -mkdir pip_pkg -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag - -# Copy and rename to tf_nightly -for f in $(ls pip_pkg/tf_nightly_cpu-*dev*macosx*.whl); do - copy_to_new_project_name "${f}" tf_nightly -done - -# Upload the built packages to pypi. -for f in $(ls pip_pkg/tf_nightly*dev*macosx*.whl); do - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${f} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${f}" - twine upload -r pypi-warehouse "${f}" || echo - else - echo "Basic PIP test FAILED, will not upload ${f} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py38_full/nonpip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py38_full/nonpip.sh deleted file mode 100644 index a90d59ff492..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py38_full/nonpip.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" -python -m virtualenv tf_build_env --system-site-packages -source tf_build_env/bin/activate - -# Install macos pip dependencies -install_macos_pip_deps sudo pip3.8 - -# Run configure. -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export TF2_BEHAVIOR=1 -export PYTHON_BIN_PATH=$(which python3.8) -yes "" | "$PYTHON_BIN_PATH" configure.py - -tag_filters="-no_oss,-oss_serial,-nomac,-no_mac$(maybe_skip_v1),-gpu,-tpu,-benchmark-test" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -bazel test --test_output=errors --config=opt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} \ - -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py38_full/pip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py38_full/pip.sh deleted file mode 100644 index f0ef8e89766..00000000000 --- a/tensorflow/tools/ci_build/release/macos/cpu_py38_full/pip.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk - -# Pick a more recent version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" - -# Install macos pip dependencies -install_macos_pip_deps sudo pip3.8 - -# Export required variables for running pip_new.sh -export OS_TYPE="MACOS" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.8' -export TF_BUILD_BOTH_CPU_PACKAGES=1 - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=release_cpu_macos" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" -export TF_TEST_TARGETS="//tensorflow/python/..." -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py38,-v1only,-gpu,-tpu,-benchmark-test' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nightly_release.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nightly_release.sh deleted file mode 100644 index 4af77739c55..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nightly_release.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 - -install_bazelisk - -python2.7 tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_cpu_linux tensorflow/tools/pip_package:build_pip_package - -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag - -# Upload the built packages to pypi. -for WHL_PATH in $(ls pip_pkg/tf_nightly_cpu-*dev*.whl); do - - WHL_DIR=$(dirname "${WHL_PATH}") - WHL_BASE_NAME=$(basename "${WHL_PATH}") - AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") - auditwheel repair --plat manylinux2010_x86_64 -w "${WHL_DIR}" "${WHL_PATH}" - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" - twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" - else - echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nonpip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nonpip.sh deleted file mode 100644 index fee64f0beb1..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nonpip.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 -# Update bazel -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -export TF2_BEHAVIOR=1 -yes "" | "$PYTHON_BIN_PATH" configure.py -tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py35,-v1only" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -set +e -bazel test --test_output=errors --config=opt --test_lang_filters=py \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nonpip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nonpip_v1.sh deleted file mode 100644 index 4231891fbdb..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/nonpip_v1.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 - -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -yes "" | "$PYTHON_BIN_PATH" configure.py -tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py35" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -bazel test --test_output=errors --config=opt --test_lang_filters=py \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip.sh deleted file mode 100644 index bdbb7f15e34..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 -# Update bazel -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.5' - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=release_cpu_linux" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" -export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py35,-v1only' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow_cpu" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip_v1.sh deleted file mode 100644 index 1e2665f4120..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip_v1.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 - -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.5' - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=opt --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going" -export TF_TEST_TARGETS="//tensorflow/python/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py35' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow_cpu" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nightly_release.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nightly_release.sh deleted file mode 100644 index 9cca17e5517..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nightly_release.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 - -install_bazelisk - -python2.7 tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.6) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_cpu_linux tensorflow/tools/pip_package:build_pip_package - -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag - -# Upload the built packages to pypi. -for WHL_PATH in $(ls pip_pkg/tf_nightly_cpu-*dev*.whl); do - - WHL_DIR=$(dirname "${WHL_PATH}") - WHL_BASE_NAME=$(basename "${WHL_PATH}") - AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") - auditwheel repair --plat manylinux2010_x86_64 -w "${WHL_DIR}" "${WHL_PATH}" - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" - twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" - else - echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nonpip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nonpip.sh deleted file mode 100644 index 6b05141f00f..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nonpip.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 -# Update bazel -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.6) -export TF2_BEHAVIOR=1 -yes "" | "$PYTHON_BIN_PATH" configure.py -tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py36,-v1only" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -set +e -bazel test --test_output=errors --config=opt --test_lang_filters=py \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nonpip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nonpip_v1.sh deleted file mode 100644 index 38d03c8868c..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/nonpip_v1.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 - -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.6) -yes "" | "$PYTHON_BIN_PATH" configure.py -tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py36" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -bazel test --test_output=errors --config=opt --test_lang_filters=py \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip.sh deleted file mode 100644 index 6277291043c..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 -# Update bazel -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.6' - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=release_cpu_linux" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" -export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py36,-v1only' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow_cpu" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip_v1.sh deleted file mode 100644 index c4d78dc3fe5..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip_v1.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 - -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.6' - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=opt --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going" -export TF_TEST_TARGETS="//tensorflow/python/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py36' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow_cpu" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nightly_release.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nightly_release.sh deleted file mode 100644 index 29fe8f4c351..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nightly_release.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.7 - -install_bazelisk - -python2.7 tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.7) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_cpu_linux tensorflow/tools/pip_package:build_pip_package - -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag - -# Upload the built packages to pypi. -for WHL_PATH in $(ls pip_pkg/tf_nightly_cpu-*dev*.whl); do - - WHL_DIR=$(dirname "${WHL_PATH}") - WHL_BASE_NAME=$(basename "${WHL_PATH}") - AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") - auditwheel repair --plat manylinux2010_x86_64 -w "${WHL_DIR}" "${WHL_PATH}" - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" - twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" - else - echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nonpip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nonpip.sh deleted file mode 100644 index db0c6056b6c..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nonpip.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.7 -# Update bazel -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.7) -export TF2_BEHAVIOR=1 -yes "" | "$PYTHON_BIN_PATH" configure.py -tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py37,-v1only" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -set +e -bazel test --test_output=errors --config=opt --test_lang_filters=py \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nonpip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nonpip_v1.sh deleted file mode 100644 index 098155aa026..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/nonpip_v1.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.7 - -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.7) -yes "" | "$PYTHON_BIN_PATH" configure.py -tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py37" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -bazel test --test_output=errors --config=opt --test_lang_filters=py \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip.sh deleted file mode 100644 index ff88ae46f39..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.7 -# Update bazel -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.7' - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=release_cpu_linux" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" -export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py37,-v1only' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow_cpu" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip_v1.sh deleted file mode 100644 index 2208327388f..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip_v1.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.7 - -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.7' - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=opt --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going" -export TF_TEST_TARGETS="//tensorflow/python/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py37' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow_cpu" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/nightly_release.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/nightly_release.sh deleted file mode 100644 index 442d6a4cc76..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/nightly_release.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.8 - -install_bazelisk - -python2.7 tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.8) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_cpu_linux tensorflow/tools/pip_package:build_pip_package - -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag - -# Upload the built packages to pypi. -for WHL_PATH in $(ls pip_pkg/tf_nightly_cpu-*dev*.whl); do - - WHL_DIR=$(dirname "${WHL_PATH}") - WHL_BASE_NAME=$(basename "${WHL_PATH}") - AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") - auditwheel repair --plat manylinux2010_x86_64 -w "${WHL_DIR}" "${WHL_PATH}" - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" - twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" - else - echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/nonpip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/nonpip.sh deleted file mode 100644 index 36da30167d0..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/nonpip.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.8 -# Update bazel -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=0 -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.8) -export TF2_BEHAVIOR=1 -yes "" | "$PYTHON_BIN_PATH" configure.py -tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py38,-v1only" - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Run tests -set +e -bazel test --test_output=errors --config=opt --test_lang_filters=py \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --build_tag_filters="${tag_filters}" \ - --test_tag_filters="${tag_filters}" -- \ - ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/pip.sh deleted file mode 100644 index 52872cfd0a6..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/pip.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.8 -# Update bazel -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="CPU" -export TF_PYTHON_VERSION='python3.8' - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Export optional variables for running pip.sh -export TF_BUILD_FLAGS="--config=release_cpu_linux" -export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1" -export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py38,-v1only' -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow_cpu" -export TF_PIP_TEST_ROOT="pip_test" - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_pip_on_cpu/build.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_pip_on_cpu/build.sh deleted file mode 100755 index 22ca5b7b567..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_pip_on_cpu/build.sh +++ /dev/null @@ -1,61 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 -# Update Bazel to the desired version -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=11 -export TF_CUDNN_VERSION=8 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.6) -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -######################## -## Build GPU pip package -######################## -bazel build --config=opt \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ - tensorflow/tools/pip_package:build_pip_package - -# Set TF nightly flag so we get the proper version of estimator -if [[ "$IS_NIGHTLY" == 1 ]]; then - NIGHTLY_FLAG="--nightly_flag" -fi - -PIP_WHL_DIR=whl -mkdir -p ${PIP_WHL_DIR} -PIP_WHL_DIR=$(readlink -f ${PIP_WHL_DIR}) # Get absolute path -bazel-bin/tensorflow/tools/pip_package/build_pip_package "${PIP_WHL_DIR}" "${NIGHTLY_FLAG}" -WHL_PATH=$(ls "${PIP_WHL_DIR}"/*.whl) - -cp "${WHL_PATH}" "$(pwd)"/. -chmod +x tensorflow/tools/ci_build/builds/docker_cpu_pip.sh -docker run -e "BAZEL_VERSION=${BAZEL_VERSION}" -e "CI_BUILD_USER=$(id -u -n)" -e "CI_BUILD_UID=$(id -u)" -e "CI_BUILD_GROUP=$(id -g -n)" -e "CI_BUILD_GID=$(id -g)" -e "CI_BUILD_HOME=/bazel_pip" -v "$(pwd)":/bazel_pip tensorflow/tensorflow:devel "./bazel_pip/tensorflow/tools/ci_build/builds/with_the_same_user" "./bazel_pip/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nightly_release.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nightly_release.sh deleted file mode 100644 index aac88b57fa7..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nightly_release.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 - -install_bazelisk - -python2.7 tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_gpu_linux tensorflow/tools/pip_package:build_pip_package - -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --nightly_flag -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --gpu --nightly_flag - -# Upload the built packages to pypi. -for WHL_PATH in $(ls pip_pkg/tf_nightly*dev*.whl); do - - WHL_DIR=$(dirname "${WHL_PATH}") - WHL_BASE_NAME=$(basename "${WHL_PATH}") - AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") - - # Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so - WHL_PATH=${AUDITED_WHL_NAME} - cp "${WHL_DIR}"/"${WHL_BASE_NAME}" "${WHL_PATH}" - echo "Copied manylinux2010 wheel file at: ${WHL_PATH}" - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" - twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" - else - echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nonpip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nonpip.sh deleted file mode 100644 index a8dfd2047ba..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nonpip.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 -# Update bazel -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=11 -export TF_CUDNN_VERSION=8 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -export TF2_BEHAVIOR=1 -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py35" - -set +e -bazel test --config=cuda --config=opt \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ - --linkopt=-lrt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --test_lang_filters=py \ - --test_tag_filters=${tag_filters} \ - --build_tag_filters=${tag_filters} \ - --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ - --test_output=errors --verbose_failures=true --keep_going \ - --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ - -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nonpip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nonpip_v1.sh deleted file mode 100644 index e0e69504f26..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/nonpip_v1.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 - -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=10 -export TF_CUDNN_VERSION=7 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.5) -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py35" - -bazel test --config=cuda --config=opt \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --test_lang_filters=py \ - --build_tag_filters=${tag_filters} \ - --test_tag_filters=${tag_filters} \ - --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ - --test_output=errors --verbose_failures=true --keep_going \ - --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ - -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip.sh deleted file mode 100644 index f178ac0754e..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 -# Update bazel -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="GPU" -export TF_PYTHON_VERSION='python3.5' - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py35' -export TF_BUILD_FLAGS="--config=release_gpu_linux " -export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ ---distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION=11 --action_env=TF_CUDNN_VERSION=8 --test_env=TF2_BEHAVIOR=1 \ ---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ ---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ ---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " -export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME="tensorflow_gpu" -export TF_PIP_TEST_ROOT="pip_test" - -# To build both tensorflow and tensorflow-gpu pip packages -export TF_BUILD_BOTH_GPU_PACKAGES=1 - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip_v1.sh deleted file mode 100644 index 6c83621269e..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip_v1.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.5 - -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="GPU" -export TF_PYTHON_VERSION='python3.5' - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=10 -export TF_CUDNN_VERSION=7 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py35' -export TF_BUILD_FLAGS="--config=opt --config=cuda --distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION --action_env=TF_CUDNN_VERSION --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain " -export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ ---distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION --action_env=TF_CUDNN_VERSION \ ---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ ---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ ---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " -export TF_TEST_TARGETS="//tensorflow/python/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME=${PROJECT_NAME} -export TF_PIP_TEST_ROOT="pip_test" - -# To build both tensorflow and tensorflow-gpu pip packages -export TF_BUILD_BOTH_GPU_PACKAGES=1 - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nightly_release.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nightly_release.sh deleted file mode 100644 index 600b4b0be8e..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nightly_release.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 - -install_bazelisk - -python2.7 tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.6) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_gpu_linux tensorflow/tools/pip_package:build_pip_package - -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --nightly_flag -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --gpu --nightly_flag - -# Upload the built packages to pypi. -for WHL_PATH in $(ls pip_pkg/tf_nightly*dev*.whl); do - - WHL_DIR=$(dirname "${WHL_PATH}") - WHL_BASE_NAME=$(basename "${WHL_PATH}") - AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") - - # Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so - WHL_PATH=${AUDITED_WHL_NAME} - cp "${WHL_DIR}"/"${WHL_BASE_NAME}" "${WHL_PATH}" - echo "Copied manylinux2010 wheel file at: ${WHL_PATH}" - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" - twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" - else - echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nonpip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nonpip.sh deleted file mode 100644 index c52acec7784..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nonpip.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 -# Update bazel -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=11 -export TF_CUDNN_VERSION=8 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.6) -export TF2_BEHAVIOR=1 -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py36" - -set +e -bazel test --config=cuda --config=opt \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ - --linkopt=-lrt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --test_lang_filters=py \ - --test_tag_filters=${tag_filters} \ - --build_tag_filters=${tag_filters} \ - --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ - --test_output=errors --verbose_failures=true --keep_going \ - --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ - -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nonpip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nonpip_v1.sh deleted file mode 100644 index 1da93811d43..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/nonpip_v1.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 - -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=10 -export TF_CUDNN_VERSION=7 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.6) -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py36" - -bazel test --config=cuda --config=opt \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --test_lang_filters=py \ - --build_tag_filters=${tag_filters} \ - --test_tag_filters=${tag_filters} \ - --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ - --test_output=errors --verbose_failures=true --keep_going \ - --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ - -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip.sh deleted file mode 100644 index 9bc559a01ab..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 -# Update bazel -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="GPU" -export TF_PYTHON_VERSION='python3.6' - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py36' -export TF_BUILD_FLAGS="--config=release_gpu_linux " -export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ ---distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION=11 --action_env=TF_CUDNN_VERSION=8 --test_env=TF2_BEHAVIOR=1 \ ---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ ---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ ---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " -export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME=="tensorflow_gpu" -export TF_PIP_TEST_ROOT="pip_test" - -# To build both tensorflow and tensorflow-gpu pip packages -export TF_BUILD_BOTH_GPU_PACKAGES=1 - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip_v1.sh deleted file mode 100644 index e3da69ebc32..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip_v1.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.6 - -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="GPU" -export TF_PYTHON_VERSION='python3.6' - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=10 -export TF_CUDNN_VERSION=7 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py36' -export TF_BUILD_FLAGS="--config=opt --config=cuda --distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION --action_env=TF_CUDNN_VERSION --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain " -export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ ---distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION --action_env=TF_CUDNN_VERSION \ ---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ ---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ ---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " -export TF_TEST_TARGETS="//tensorflow/python/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME=${PROJECT_NAME} -export TF_PIP_TEST_ROOT="pip_test" - -# To build both tensorflow and tensorflow-gpu pip packages -export TF_BUILD_BOTH_GPU_PACKAGES=1 - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nightly_release.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nightly_release.sh deleted file mode 100644 index a9e51461715..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nightly_release.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.7 - -install_bazelisk - -python2.7 tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.7) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_gpu_linux tensorflow/tools/pip_package:build_pip_package - -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --nightly_flag -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --gpu --nightly_flag - -# Upload the built packages to pypi. -for WHL_PATH in $(ls pip_pkg/tf_nightly*dev*.whl); do - - WHL_DIR=$(dirname "${WHL_PATH}") - WHL_BASE_NAME=$(basename "${WHL_PATH}") - AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") - - # Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so - WHL_PATH=${AUDITED_WHL_NAME} - cp "${WHL_DIR}"/"${WHL_BASE_NAME}" "${WHL_PATH}" - echo "Copied manylinux2010 wheel file at: ${WHL_PATH}" - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" - twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" - else - echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nonpip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nonpip.sh deleted file mode 100644 index bf5fabba741..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nonpip.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.7 -# Update bazel -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=11 -export TF_CUDNN_VERSION=8 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.7) -export TF2_BEHAVIOR=1 -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py37" - -set +e -bazel test --config=cuda --config=opt \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ - --linkopt=-lrt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --test_lang_filters=py \ - --build_tag_filters=${tag_filters} \ - --test_tag_filters=${tag_filters} \ - --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ - --test_output=errors --verbose_failures=true --keep_going \ - --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ - -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nonpip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nonpip_v1.sh deleted file mode 100644 index a620e3c92d2..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/nonpip_v1.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.7 - -install_bazelisk - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=10 -export TF_CUDNN_VERSION=7 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.7) -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py37" - -bazel test --config=cuda --config=opt \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \ - --linkopt=-lrt \ - --test_lang_filters=py \ - --build_tag_filters=${tag_filters} \ - --test_tag_filters=${tag_filters} \ - --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ - --test_output=errors --verbose_failures=true --keep_going \ - --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ - -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip.sh deleted file mode 100644 index 71d6f3e6401..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.7 -# Update bazel -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="GPU" -export TF_PYTHON_VERSION='python3.7' - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py37' -export TF_BUILD_FLAGS="--config=release_gpu_linux " -export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ ---distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION=11 --action_env=TF_CUDNN_VERSION=8 --test_env=TF2_BEHAVIOR=1 \ ---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ ---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ ---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " -export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME=="tensorflow_gpu" -export TF_PIP_TEST_ROOT="pip_test" - -# To build both tensorflow and tensorflow-gpu pip packages -export TF_BUILD_BOTH_GPU_PACKAGES=1 - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip_v1.sh deleted file mode 100644 index a0fb0c40001..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip_v1.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.7 - -install_bazelisk - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="GPU" -export TF_PYTHON_VERSION='python3.7' - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=10 -export TF_CUDNN_VERSION=7 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py37' -export TF_BUILD_FLAGS="--config=opt --config=cuda --distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION --action_env=TF_CUDNN_VERSION --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain " -export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ ---distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION --action_env=TF_CUDNN_VERSION \ ---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ ---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ ---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " -export TF_TEST_TARGETS="//tensorflow/python/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME=${PROJECT_NAME} -export TF_PIP_TEST_ROOT="pip_test" - -# To build both tensorflow and tensorflow-gpu pip packages -export TF_BUILD_BOTH_GPU_PACKAGES=1 - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/nightly_release.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/nightly_release.sh deleted file mode 100644 index 0b8fd1380f2..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/nightly_release.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.8 - -pip3.7 install --upgrade auditwheel --user - -update_bazel_linux - -python2.7 tensorflow/tools/ci_build/update_version.py --nightly - -# Run configure. -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.8) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Build the pip package -bazel build --config=release_gpu_linux tensorflow/tools/pip_package:build_pip_package - -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --nightly_flag -./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --gpu --nightly_flag - -# Upload the built packages to pypi. -for WHL_PATH in $(ls pip_pkg/tf_nightly*dev*.whl); do - - WHL_DIR=$(dirname "${WHL_PATH}") - WHL_BASE_NAME=$(basename "${WHL_PATH}") - AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") - - # Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so - WHL_PATH=${AUDITED_WHL_NAME} - cp "${WHL_DIR}"/"${WHL_BASE_NAME}" "${WHL_PATH}" - echo "Copied manylinux2010 wheel file at: ${WHL_PATH}" - - # test the whl pip package - chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh - ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} - RETVAL=$? - - # Upload the PIP package if whl test passes. - if [ ${RETVAL} -eq 0 ]; then - echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" - twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" - else - echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" - return 1 - fi -done diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/nonpip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/nonpip.sh deleted file mode 100644 index 5f29daf36e0..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/nonpip.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.8 -# Update bazel -update_bazel_linux - -# Run configure. -export TF_NEED_GCP=1 -export TF_NEED_HDFS=1 -export TF_NEED_S3=1 -export TF_NEED_CUDA=1 -export TF_CUDA_VERSION=11 -export TF_CUDNN_VERSION=8 -export TF_NEED_TENSORRT=1 -export TENSORRT_INSTALL_PATH=/usr/local/tensorrt -export CC_OPT_FLAGS='-mavx' -export PYTHON_BIN_PATH=$(which python3.8) -export TF2_BEHAVIOR=1 -export PROJECT_NAME="tensorflow_gpu" -export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib" -export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_70,sm_75,compute_80 - -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py38" - -test +e -bazel test --config=cuda --config=opt \ - --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \ - --linkopt=-lrt \ - --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ - --test_lang_filters=py \ - --build_tag_filters=${tag_filters} \ - --test_tag_filters=${tag_filters} \ - --test_timeout="300,450,1200,3600" --local_test_jobs=4 \ - --test_output=errors --verbose_failures=true --keep_going \ - --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ - -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/pip.sh deleted file mode 100644 index f49b77bae70..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/pip.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -install_ubuntu_16_pip_deps pip3.8 -# Update bazel -update_bazel_linux - -# Export required variables for running pip.sh -export OS_TYPE="UBUNTU" -export CONTAINER_TYPE="GPU" -export TF_PYTHON_VERSION='python3.8' - -# Run configure. -export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION}) -yes "" | "$PYTHON_BIN_PATH" configure.py - -# Get the default test targets for bazel. -source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh - -# Export optional variables for running pip.sh -export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py38' -export TF_BUILD_FLAGS="--config=release_gpu_linux " -export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \ ---distinct_host_configuration=false \ ---action_env=TF_CUDA_VERSION=11 --action_env=TF_CUDNN_VERSION=8 --test_env=TF2_BEHAVIOR=1 \ ---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \ ---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \ ---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " -export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " -export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. -export TF_PROJECT_NAME=="tensorflow_gpu" -export TF_PIP_TEST_ROOT="pip_test" - -# To build both tensorflow and tensorflow-gpu pip packages -export TF_BUILD_BOTH_GPU_PACKAGES=1 - -./tensorflow/tools/ci_build/builds/pip_new.sh diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/libtensorflow/cpu/build.sh b/tensorflow/tools/ci_build/release/ubuntu_16/libtensorflow/cpu/build.sh deleted file mode 100644 index 1504688dcbc..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/libtensorflow/cpu/build.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e - -# Source the external common scripts. -source tensorflow/tools/ci_build/release/common.sh - - -# Install latest bazel -install_bazelisk -which bazel - -# Install realpath -sudo apt-get install realpath - -# Update the version string to nightly -if [ -n "${IS_NIGHTLY_BUILD}" ]; then - ./tensorflow/tools/ci_build/update_version.py --nightly -fi - -./tensorflow/tools/ci_build/linux/libtensorflow.sh - -# Copy the nightly version update script -if [ -n "${IS_NIGHTLY_BUILD}" ]; then - cp tensorflow/tools/ci_build/builds/libtensorflow_nightly_symlink.sh lib_package -fi - -# Upload to go/tf-sizetracker -python3 ./tensorflow/tools/ci_build/sizetrack_helper.py \ - --team tensorflow_libtensorflow \ - --artifact_id ubuntu_cpu_nightly \ - --upload \ - --artifact "$(find lib_package -iname "libtensorflow*.tar.gz" -not -iname "*jni*" | head -n 1)" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/libtensorflow/gpu/build.sh b/tensorflow/tools/ci_build/release/ubuntu_16/libtensorflow/gpu/build.sh deleted file mode 100644 index d294311d1ff..00000000000 --- a/tensorflow/tools/ci_build/release/ubuntu_16/libtensorflow/gpu/build.sh +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e - -# Source the external common scripts. -source tensorflow/tools/ci_build/release/common.sh - - -# Install latest bazel -install_bazelisk -which bazel - -# Install realpath -sudo apt-get install realpath - -export TF_NEED_CUDA=1 - -# Update the version string to nightly -if [ -n "${IS_NIGHTLY_BUILD}" ]; then - ./tensorflow/tools/ci_build/update_version.py --nightly -fi - -./tensorflow/tools/ci_build/linux/libtensorflow.sh - -# Copy the nightly version update script -if [ -n "${IS_NIGHTLY_BUILD}" ]; then - cp tensorflow/tools/ci_build/builds/libtensorflow_nightly_symlink.sh lib_package -fi diff --git a/tensorflow/tools/ci_build/release/windows/cpu_libtensorflow/nightly.bat b/tensorflow/tools/ci_build/release/windows/cpu_libtensorflow/nightly.bat deleted file mode 100644 index dcc03e784db..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_libtensorflow/nightly.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\bazel\run_libtensorflow.bat || exit /b - -copy lib_package %TF_ARTIFACTS_DIR%\lib_package diff --git a/tensorflow/tools/ci_build/release/windows/cpu_libtensorflow/release.bat b/tensorflow/tools/ci_build/release/windows/cpu_libtensorflow/release.bat deleted file mode 100644 index 67941234b15..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_libtensorflow/release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\bazel\run_libtensorflow.bat || exit /b 1 - -copy lib_package %TF_ARTIFACTS_DIR%\lib_package diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/nightly.bat b/tensorflow/tools/ci_build/release/windows/cpu_py35_full/nightly.bat deleted file mode 100644 index 979a30e046c..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/nightly.bat +++ /dev/null @@ -1,22 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= -echo on -setlocal enableextensions enabledelayedexpansion - -SET PYTHON_DIRECTORY=Python35 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/nightly_release.bat b/tensorflow/tools/ci_build/release/windows/cpu_py35_full/nightly_release.bat deleted file mode 100644 index 6ed1088893f..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/nightly_release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python35 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/release.bat b/tensorflow/tools/ci_build/release/windows/cpu_py35_full/release.bat deleted file mode 100644 index 175917d7cad..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python35 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_v1.bat b/tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_v1.bat deleted file mode 100644 index e0f0bfeae7b..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py35_full/release_v1.bat +++ /dev/null @@ -1,23 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python35 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build - -for %%a in ("%~dp0.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py36_full/nightly.bat b/tensorflow/tools/ci_build/release/windows/cpu_py36_full/nightly.bat deleted file mode 100644 index fd1854603f5..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py36_full/nightly.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python36 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py36_full/nightly_release.bat b/tensorflow/tools/ci_build/release/windows/cpu_py36_full/nightly_release.bat deleted file mode 100644 index 3af98dddeae..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py36_full/nightly_release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python36 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py36_full/release.bat b/tensorflow/tools/ci_build/release/windows/cpu_py36_full/release.bat deleted file mode 100644 index 85b75053eff..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py36_full/release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python36 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py36_full/release_v1.bat b/tensorflow/tools/ci_build/release/windows/cpu_py36_full/release_v1.bat deleted file mode 100644 index 44483213724..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py36_full/release_v1.bat +++ /dev/null @@ -1,23 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python36 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build - -for %%a in ("%~dp0.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh \ No newline at end of file diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py37_full/nightly.bat b/tensorflow/tools/ci_build/release/windows/cpu_py37_full/nightly.bat deleted file mode 100644 index 69b9449b0c3..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py37_full/nightly.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python37 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py37_full/nightly_release.bat b/tensorflow/tools/ci_build/release/windows/cpu_py37_full/nightly_release.bat deleted file mode 100644 index 850c21ee962..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py37_full/nightly_release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python37 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py37_full/release.bat b/tensorflow/tools/ci_build/release/windows/cpu_py37_full/release.bat deleted file mode 100644 index d8a6673ba4c..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py37_full/release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python37 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py37_full/release_v1.bat b/tensorflow/tools/ci_build/release/windows/cpu_py37_full/release_v1.bat deleted file mode 100644 index ac549eca53e..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py37_full/release_v1.bat +++ /dev/null @@ -1,23 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python37 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh \ No newline at end of file diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py38_full/nightly.bat b/tensorflow/tools/ci_build/release/windows/cpu_py38_full/nightly.bat deleted file mode 100644 index 0d5b3a7fff8..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py38_full/nightly.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python38 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py38_full/nightly_release.bat b/tensorflow/tools/ci_build/release/windows/cpu_py38_full/nightly_release.bat deleted file mode 100644 index 2456b1e26bb..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py38_full/nightly_release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python38 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu" diff --git a/tensorflow/tools/ci_build/release/windows/cpu_py38_full/release.bat b/tensorflow/tools/ci_build/release/windows/cpu_py38_full/release.bat deleted file mode 100644 index 86adcda0bb9..00000000000 --- a/tensorflow/tools/ci_build/release/windows/cpu_py38_full/release.bat +++ /dev/null @@ -1,21 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python38 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" - diff --git a/tensorflow/tools/ci_build/release/windows/gpu_libtensorflow/nightly.bat b/tensorflow/tools/ci_build/release/windows/gpu_libtensorflow/nightly.bat deleted file mode 100644 index 8ab78bef3ca..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_libtensorflow/nightly.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\bazel\run_libtensorflow.bat || exit /b - -copy lib_package %TF_ARTIFACTS_DIR%\lib_package diff --git a/tensorflow/tools/ci_build/release/windows/gpu_libtensorflow/release.bat b/tensorflow/tools/ci_build/release/windows/gpu_libtensorflow/release.bat deleted file mode 100644 index 8ab78bef3ca..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_libtensorflow/release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\bazel\run_libtensorflow.bat || exit /b - -copy lib_package %TF_ARTIFACTS_DIR%\lib_package diff --git a/tensorflow/tools/ci_build/release/windows/gpu_pip_on_cpu/build.bat b/tensorflow/tools/ci_build/release/windows/gpu_pip_on_cpu/build.bat deleted file mode 100644 index 213de532069..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_pip_on_cpu/build.bat +++ /dev/null @@ -1,21 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python36 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\integration\gpu_pip_on_cpu\run.bat - diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py35_full/nightly.bat b/tensorflow/tools/ci_build/release/windows/gpu_py35_full/nightly.bat deleted file mode 100644 index ba8dee59853..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py35_full/nightly.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python35 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py35_full/nightly_release.bat b/tensorflow/tools/ci_build/release/windows/gpu_py35_full/nightly_release.bat deleted file mode 100644 index 43e6414a74b..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py35_full/nightly_release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python35 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py35_full/release.bat b/tensorflow/tools/ci_build/release/windows/gpu_py35_full/release.bat deleted file mode 100644 index 86c118b2f83..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py35_full/release.bat +++ /dev/null @@ -1,23 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python35 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py35_full/release_v1.bat b/tensorflow/tools/ci_build/release/windows/gpu_py35_full/release_v1.bat deleted file mode 100644 index 55e4e4f5782..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py35_full/release_v1.bat +++ /dev/null @@ -1,23 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python35 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py36_full/nightly.bat b/tensorflow/tools/ci_build/release/windows/gpu_py36_full/nightly.bat deleted file mode 100644 index 9624ca5f5b2..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py36_full/nightly.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python36 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py36_full/nightly_release.bat b/tensorflow/tools/ci_build/release/windows/gpu_py36_full/nightly_release.bat deleted file mode 100644 index 15ec83c054e..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py36_full/nightly_release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python36 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py36_full/release.bat b/tensorflow/tools/ci_build/release/windows/gpu_py36_full/release.bat deleted file mode 100644 index cc4f84afbee..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py36_full/release.bat +++ /dev/null @@ -1,23 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python36 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh \ No newline at end of file diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py36_full/release_v1.bat b/tensorflow/tools/ci_build/release/windows/gpu_py36_full/release_v1.bat deleted file mode 100644 index a66ca900e47..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py36_full/release_v1.bat +++ /dev/null @@ -1,23 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python36 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py37_full/nightly.bat b/tensorflow/tools/ci_build/release/windows/gpu_py37_full/nightly.bat deleted file mode 100644 index c6141c42916..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py37_full/nightly.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python37 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py37_full/nightly_release.bat b/tensorflow/tools/ci_build/release/windows/gpu_py37_full/nightly_release.bat deleted file mode 100644 index 1eb65d8a284..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py37_full/nightly_release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python37 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py37_full/release.bat b/tensorflow/tools/ci_build/release/windows/gpu_py37_full/release.bat deleted file mode 100644 index 5fa798e3eb8..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py37_full/release.bat +++ /dev/null @@ -1,23 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python37 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh \ No newline at end of file diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py37_full/release_v1.bat b/tensorflow/tools/ci_build/release/windows/gpu_py37_full/release_v1.bat deleted file mode 100644 index 059e28134c8..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py37_full/release_v1.bat +++ /dev/null @@ -1,23 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python37 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py38_full/nightly.bat b/tensorflow/tools/ci_build/release/windows/gpu_py38_full/nightly.bat deleted file mode 100644 index dcbed63089e..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py38_full/nightly.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python38 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py38_full/nightly_release.bat b/tensorflow/tools/ci_build/release/windows/gpu_py38_full/nightly_release.bat deleted file mode 100644 index 670793340e8..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py38_full/nightly_release.bat +++ /dev/null @@ -1,20 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python38 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly diff --git a/tensorflow/tools/ci_build/release/windows/gpu_py38_full/release.bat b/tensorflow/tools/ci_build/release/windows/gpu_py38_full/release.bat deleted file mode 100644 index fa1fc131145..00000000000 --- a/tensorflow/tools/ci_build/release/windows/gpu_py38_full/release.bat +++ /dev/null @@ -1,23 +0,0 @@ -:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. -:: -:: Licensed under the Apache License, Version 2.0 (the "License"); -:: you may not use this file except in compliance with the License. -:: You may obtain a copy of the License at -:: -:: http://www.apache.org/licenses/LICENSE-2.0 -:: -:: Unless required by applicable law or agreed to in writing, software -:: distributed under the License is distributed on an "AS IS" BASIS, -:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -:: See the License for the specific language governing permissions and -:: limitations under the License. -:: ============================================================================= - -SET PYTHON_DIRECTORY=Python38 - -CALL tensorflow\tools\ci_build\release\common_win.bat - -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh diff --git a/tensorflow/tools/ci_build/release/windows/upload_nightly_pip/upload.sh b/tensorflow/tools/ci_build/release/windows/upload_nightly_pip/upload.sh deleted file mode 100644 index 609c316cca7..00000000000 --- a/tensorflow/tools/ci_build/release/windows/upload_nightly_pip/upload.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -set -x - -source tensorflow/tools/ci_build/release/common.sh - -sudo pip install --upgrade twine - -# Copy and rename to tf_nightly -for f in $(ls "${TF_FILE_DIR}"/tf_nightly_gpu*dev*cp3*-cp3*-win_amd64.whl); do - copy_to_new_project_name "${f}" tf_nightly -done - -# Upload the built packages to pypi. -for f in $(ls "${TF_FILE_DIR}"/tf_nightly*dev*cp3*-cp3*-win_amd64.whl); do - twine upload -r pypi-warehouse "$f" || echo -done diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index e2aa438cb1a..b925f6b9c36 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -770,6 +770,12 @@ renames = { 'tf.linalg.matrix_transpose', 'tf.matrix_triangular_solve': 'tf.linalg.triangular_solve', + 'tf.mixed_precision.DynamicLossScale': + 'tf.compat.v1.mixed_precision.DynamicLossScale', + 'tf.mixed_precision.FixedLossScale': + 'tf.compat.v1.mixed_precision.FixedLossScale', + 'tf.mixed_precision.LossScale': + 'tf.compat.v1.mixed_precision.LossScale', 'tf.metrics.accuracy': 'tf.compat.v1.metrics.accuracy', 'tf.metrics.auc': @@ -838,6 +844,12 @@ renames = { 'tf.compat.v1.metrics.true_positives_at_thresholds', 'tf.min_max_variable_partitioner': 'tf.compat.v1.min_max_variable_partitioner', + 'tf.mixed_precision.MixedPrecisionLossScaleOptimizer': + 'tf.compat.v1.mixed_precision.MixedPrecisionLossScaleOptimizer', + 'tf.mixed_precision.disable_mixed_precision_graph_rewrite': + 'tf.compat.v1.mixed_precision.disable_mixed_precision_graph_rewrite', + 'tf.mixed_precision.enable_mixed_precision_graph_rewrite': + 'tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite', 'tf.mod': 'tf.math.floormod', 'tf.model_variables': diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index d5cb6b6bc38..222a6bb262b 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -393,6 +393,10 @@ tensorflow::tensor_float_32_execution_enabled [get_compiler_ir] # tfe tensorflow::GetCompilerIr +stream_executor::port::internal_statusor::Helper::Crash [tensor_handle] # tfe tensorflow::TensorHandle::Tensor + +[python_api_dispatcher] # python_api_dispatcher +tensorflow::PythonAPIDispatcher diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 1998e47a3ad..6adc0a73610 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -2,6 +2,9 @@ # Doc generator load("//tensorflow:tensorflow.bzl", "py_test") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test") load( "//tensorflow/core/platform:build_config_root.bzl", @@ -129,22 +132,11 @@ py_test( py_library( name = "doc_controls", srcs = ["doc_controls.py"], + compatible_with = get_compatible_with_portable(), srcs_version = "PY2AND3", visibility = ["//visibility:public"], ) -py_test( - name = "doc_controls_test", - size = "small", - srcs = ["doc_controls_test.py"], - python_version = "PY3", - srcs_version = "PY2AND3", - deps = [ - ":doc_controls", - "//tensorflow/python:platform_test", - ], -) - py_test( name = "generate2_test", size = "medium", diff --git a/tensorflow/tools/docs/build_cc_api_headers.py b/tensorflow/tools/docs/build_cc_api_headers.py new file mode 100644 index 00000000000..c0b67429f73 --- /dev/null +++ b/tensorflow/tools/docs/build_cc_api_headers.py @@ -0,0 +1,63 @@ +# Lint as: python3 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generate Java reference docs for TensorFlow.org.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pathlib +import subprocess + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS + +# These flags are required by infrastructure, not all of them are used. +flags.DEFINE_string('output_dir', None, + ("Use this branch as the root version and don't" + ' create in version directory')) + +# __file__ is the path to this file +DOCS_TOOLS_DIR = pathlib.Path(__file__).resolve().parent +TENSORFLOW_ROOT = DOCS_TOOLS_DIR.parents[2] + + +def build_headers(output_dir): + """Builds the headers files for TF.""" + + # `$ yes | configure` + yes = subprocess.Popen(['yes', ''], stdout=subprocess.PIPE) + configure = subprocess.Popen([TENSORFLOW_ROOT / 'configure'], + stdin=yes.stdout, + cwd=TENSORFLOW_ROOT) + configure.communicate() + + subprocess.check_call(['bazel', 'build', 'tensorflow/cc:cc_ops'], + cwd=TENSORFLOW_ROOT) + subprocess.check_call( + ['cp', '--dereference', '-r', 'bazel-bin', output_dir / 'bazel-genfiles'], + cwd=TENSORFLOW_ROOT) + + +def main(argv): + del argv + build_headers(pathlib.Path(FLAGS.output_dir)) + + +if __name__ == '__main__': + flags.mark_flags_as_required(['output_dir']) + app.run(main) diff --git a/tensorflow/tools/docs/doc_controls.py b/tensorflow/tools/docs/doc_controls.py index 27a1d2075e9..6151e7029d2 100644 --- a/tensorflow/tools/docs/doc_controls.py +++ b/tensorflow/tools/docs/doc_controls.py @@ -18,6 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +_DEPRECATED = "_tf_docs_deprecated" + + +def set_deprecated(obj): + """Explicitly tag an object as deprecated for the doc generator.""" + setattr(obj, _DEPRECATED, None) + return obj + + _DO_NOT_DOC = "_tf_docs_do_not_document" @@ -242,81 +251,33 @@ def for_subclass_implementers(obj): do_not_doc_in_subclasses = for_subclass_implementers +_DOC_PRIVATE = "_tf_docs_doc_private" -def should_skip(obj): - """Returns true if docs generation should be skipped for this object. - checks for the `do_not_generate_docs` or `do_not_doc_inheritable` decorators. +def doc_private(obj): + """A decorator: Generates docs for private methods/functions. + + For example: + + ``` + class Try: + + @doc_controls.doc_private + def _private(self): + ... + ``` + + As a rule of thumb, private(beginning with `_`) methods/functions are + not documented. + + This decorator allows to force document a private method/function. Args: - obj: The object to document, or skip. + obj: The class-attribute to hide from the generated docs. Returns: - True if the object should be skipped + obj """ - # Unwrap fget if the object is a property - if isinstance(obj, property): - obj = obj.fget - return hasattr(obj, _DO_NOT_DOC) or hasattr(obj, _DO_NOT_DOC_INHERITABLE) - - -def should_skip_class_attr(cls, name): - """Returns true if docs should be skipped for this class attribute. - - Args: - cls: The class the attribute belongs to. - name: The name of the attribute. - - Returns: - True if the attribute should be skipped. - """ - # Get the object with standard lookup, from the nearest - # defining parent. - try: - obj = getattr(cls, name) - except AttributeError: - # Avoid error caused by enum metaclasses in python3 - if name in ("name", "value"): - return True - raise - - # Unwrap fget if the object is a property - if isinstance(obj, property): - obj = obj.fget - - # Skip if the object is decorated with `do_not_generate_docs` or - # `do_not_doc_inheritable` - if should_skip(obj): - return True - - # Use __dict__ lookup to get the version defined in *this* class. - obj = cls.__dict__.get(name, None) - if isinstance(obj, property): - obj = obj.fget - if obj is not None: - # If not none, the object is defined in *this* class. - # Do not skip if decorated with `for_subclass_implementers`. - if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS): - return False - - # for each parent class - for parent in cls.__mro__[1:]: - obj = getattr(parent, name, None) - - if obj is None: - continue - - if isinstance(obj, property): - obj = obj.fget - - # Skip if the parent's definition is decorated with `do_not_doc_inheritable` - # or `for_subclass_implementers` - if hasattr(obj, _DO_NOT_DOC_INHERITABLE): - return True - - if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS): - return True - - # No blockng decorators --> don't skip - return False + setattr(obj, _DOC_PRIVATE, None) + return obj diff --git a/tensorflow/tools/docs/doc_controls_test.py b/tensorflow/tools/docs/doc_controls_test.py deleted file mode 100644 index d5eb4ffc000..00000000000 --- a/tensorflow/tools/docs/doc_controls_test.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for documentation control decorators.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.platform import googletest -from tensorflow.tools.docs import doc_controls - - -class DocControlsTest(googletest.TestCase): - - def test_do_not_generate_docs(self): - - @doc_controls.do_not_generate_docs - def dummy_function(): - pass - - self.assertTrue(doc_controls.should_skip(dummy_function)) - - def test_do_not_doc_on_method(self): - """The simple decorator is not aware of inheritance.""" - - class Parent(object): - - @doc_controls.do_not_generate_docs - def my_method(self): - pass - - class Child(Parent): - - def my_method(self): - pass - - class GrandChild(Child): - pass - - self.assertTrue(doc_controls.should_skip(Parent.my_method)) - self.assertFalse(doc_controls.should_skip(Child.my_method)) - self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) - - self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertFalse(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertFalse( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - - def test_do_not_doc_inheritable(self): - - class Parent(object): - - @doc_controls.do_not_doc_inheritable - def my_method(self): - pass - - class Child(Parent): - - def my_method(self): - pass - - class GrandChild(Child): - pass - - self.assertTrue(doc_controls.should_skip(Parent.my_method)) - self.assertFalse(doc_controls.should_skip(Child.my_method)) - self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) - - self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - - def test_do_not_doc_inheritable_property(self): - - class Parent(object): - - @property - @doc_controls.do_not_doc_inheritable - def my_method(self): - pass - - class Child(Parent): - - @property - def my_method(self): - pass - - class GrandChild(Child): - pass - - self.assertTrue(doc_controls.should_skip(Parent.my_method)) - self.assertFalse(doc_controls.should_skip(Child.my_method)) - self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) - - self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - - def test_do_not_doc_inheritable_staticmethod(self): - - class GrandParent(object): - - def my_method(self): - pass - - class Parent(GrandParent): - - @staticmethod - @doc_controls.do_not_doc_inheritable - def my_method(): - pass - - class Child(Parent): - - @staticmethod - def my_method(): - pass - - class GrandChild(Child): - pass - - self.assertFalse(doc_controls.should_skip(GrandParent.my_method)) - self.assertTrue(doc_controls.should_skip(Parent.my_method)) - self.assertFalse(doc_controls.should_skip(Child.my_method)) - self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) - - self.assertFalse( - doc_controls.should_skip_class_attr(GrandParent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - - def test_for_subclass_implementers(self): - - class GrandParent(object): - - def my_method(self): - pass - - class Parent(GrandParent): - - @doc_controls.for_subclass_implementers - def my_method(self): - pass - - class Child(Parent): - pass - - class GrandChild(Child): - - def my_method(self): - pass - - class Grand2Child(Child): - pass - - self.assertFalse( - doc_controls.should_skip_class_attr(GrandParent, 'my_method')) - self.assertFalse(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(Grand2Child, 'my_method')) - - def test_for_subclass_implementers_short_circuit(self): - - class GrandParent(object): - - @doc_controls.for_subclass_implementers - def my_method(self): - pass - - class Parent(GrandParent): - - def my_method(self): - pass - - class Child(Parent): - - @doc_controls.do_not_doc_inheritable - def my_method(self): - pass - - class GrandChild(Child): - - @doc_controls.for_subclass_implementers - def my_method(self): - pass - - class Grand2Child(Child): - pass - - self.assertFalse( - doc_controls.should_skip_class_attr(GrandParent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertFalse( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(Grand2Child, 'my_method')) - - -if __name__ == '__main__': - googletest.main() diff --git a/tensorflow/tools/docs/generate2.py b/tensorflow/tools/docs/generate2.py index 0b3b9e00bb6..fd59772cd6a 100644 --- a/tensorflow/tools/docs/generate2.py +++ b/tensorflow/tools/docs/generate2.py @@ -173,7 +173,7 @@ def build_docs(output_dir, code_url_prefix, search_hints, gen_report): if not name.startswith("_"): doc_controls.hide_from_search(obj) - for cls in [tf.Module, tf.keras.layers.Layer]: + for cls in [tf.Module, tf.keras.layers.Layer, tf.keras.optimizers.Optimizer]: doc_controls.decorate_all_class_attributes( decorator=doc_controls.do_not_doc_in_subclasses, cls=cls, diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index c9d2932bcee..f25a6446813 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -37,6 +37,9 @@ transitive_hdrs( "//tensorflow/cc/saved_model:reader", "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/c/experimental/filesystem:filesystem_interface", + "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", + "//tensorflow/c:kernels_hdrs", + "//tensorflow/c:ops_hdrs", # WARNING: None of the C/C++ code under python/ has any API guarantees, and TF team # reserves the right to change APIs and other header-level interfaces. If your custom # op uses these headers, it may break when users upgrade their version of tensorflow. @@ -130,7 +133,7 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/keras:combinations", "//tensorflow/python/keras/layers/preprocessing:preprocessing_test_utils", "//tensorflow/python/keras/distribute:distribute_test_lib_pip", - "//tensorflow/python/keras/mixed_precision/experimental:test_util", + "//tensorflow/python/keras/mixed_precision:test_util", "//tensorflow/python/keras/tests:model_subclassing_test_util", "//tensorflow/python/keras/tests:model_architectures", "//tensorflow/python/keras/benchmarks:keras_benchmark_lib_pip", @@ -151,9 +154,9 @@ COMMON_PIP_DEPS = [ "//tensorflow/tools/common:test_module1", "//tensorflow/tools/common:traverse", "//tensorflow/python/distribute:parameter_server_strategy_v2", - "//tensorflow/python/distribute/client:client", - "//tensorflow/python/distribute/client:remote_eager_lib", - "//tensorflow/python/distribute/client:metric_utils", + "//tensorflow/python/distribute/coordinator:cluster_coordinator", + "//tensorflow/python/distribute/coordinator:remote_eager_lib", + "//tensorflow/python/distribute/coordinator:metric_utils", ] # On Windows, python binary is a zip file of runfiles tree. diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index e4a69d22a6f..5c012df896f 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -44,6 +44,7 @@ from setuptools import setup from setuptools.command.install import install as InstallCommandBase from setuptools.dist import Distribution + # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. @@ -51,35 +52,10 @@ from setuptools.dist import Distribution # tensorflow/core/public/version.h _VERSION = '2.4.0' -REQUIRED_PACKAGES = [ - 'absl-py >= 0.9.0', - 'astunparse == 1.6.3', - 'flatbuffers >= 1.12', - 'gast == 0.3.3', - 'google_pasta >= 0.1.8', - 'h5py >= 2.10.0, < 2.11.0', - 'keras_preprocessing >= 1.1.1, < 1.2', - # TODO(mihaimaruseac): numpy 1.19.0 has ABI breakage - # https://github.com/numpy/numpy/pull/15355 - 'numpy >= 1.16.0, < 1.19.0', - 'opt_einsum >= 2.3.2', - 'portpicker >= 1.3.0', # Needed by tf.__internal__.distribute.combinations. - 'protobuf >= 3.9.2', - 'tensorboard >= 2.3.0, < 3', - 'tensorflow_estimator >= 2.3.0, < 2.4.0', - 'termcolor >= 1.1.0', - 'typing_extensions >= 3.7.4.2', - 'wrapt >= 1.11.1', - 'wheel >= 0.26', - 'six >= 1.12.0', -] - -if sys.byteorder == 'little': - # grpcio does not build correctly on big-endian machines due to lack of - # BoringSSL support. - # See https://github.com/tensorflow/tensorflow/issues/17882. - REQUIRED_PACKAGES.append('grpcio >= 1.8.6') +# We use the same setup.py for all tensorflow_* packages and for the nightly +# equivalents (tf_nightly_*). The package is controlled from the argument line +# when building the pip package. project_name = 'tensorflow' if '--project_name' in sys.argv: project_name_idx = sys.argv.index('--project_name') @@ -87,13 +63,72 @@ if '--project_name' in sys.argv: sys.argv.remove('--project_name') sys.argv.pop(project_name_idx) -# tf-nightly should depend on tb-nightly + +# All versions of TF need these packages. We use the `~=` syntax to pin packages +# to the latest major.minor release accepting all other patches on top of that. +# If we already know of a patched version, we pin to that. +# For packages that don't have yet a stable release, we pin using `~= 0.x` which +# means we accept any `0.y` version (y >= x) but not the first major release. We +# will need additional testing for that. +# NOTE: This assumes that all packages follow SemVer. If a packages follows a +# different versioning scheme (e.g., PVP), we use different bound specifier and +# comment the versioning scheme. +# NOTE: Please add test only packages to `TEST_PACKAGES` below. +REQUIRED_PACKAGES = [ + 'absl-py ~= 0.10', + 'astunparse ~= 1.6.3', + 'flatbuffers ~= 1.12.0', + 'google_pasta ~= 0.2', + 'h5py ~= 2.10.0', + 'keras_preprocessing ~= 1.1.2', + 'numpy ~= 1.19.2', + 'opt_einsum ~= 3.3.0', + 'protobuf ~= 3.13.0', + 'six ~= 1.15.0', + 'termcolor ~= 1.1.0', + 'typing_extensions ~= 3.7.4', + 'wheel ~= 0.35', + 'wrapt ~= 1.12.1', + # These packages needs to be pinned exactly as newer versions are + # incompatible with the rest of the ecosystem + 'gast == 0.3.3', + # TensorFlow ecosystem packages that TF exposes API for + # These need to be in sync with the existing TF version + # They are updated during the release process + # When updating these, please also update the nightly versions below + 'tensorboard ~= 2.3', + 'tensorflow_estimator ~= 2.3.0', +] + + +# For nightly packages, instead of dependening on tensorboard and +# tensorflow_estimator, we depend on their nightly equivalent. +# When updating these, make sure to also update the release versions above. +# NOTE: the nightly versions are one version ahead of the release ones! +# NOTE: the nightly versions specify alpha/dev! if 'tf_nightly' in project_name: for i, pkg in enumerate(REQUIRED_PACKAGES): if 'tensorboard' in pkg: - REQUIRED_PACKAGES[i] = 'tb-nightly >= 2.4.0a0, < 3.0.0a0' + REQUIRED_PACKAGES[i] = 'tb-nightly ~= 2.4.0.a' elif 'tensorflow_estimator' in pkg: - REQUIRED_PACKAGES[i] = 'tf-estimator-nightly' + REQUIRED_PACKAGES[i] = 'tf-estimator-nightly ~= 2.4.0.dev' + + +# grpcio does not build correctly on big-endian machines due to lack of +# BoringSSL support. +# See https://github.com/tensorflow/tensorflow/issues/17882. +if sys.byteorder == 'little': + REQUIRED_PACKAGES.append('grpcio ~= 1.32.0') + + +# Packages which are only needed for testing code. +# Please don't add test-only packages to `REQUIRED_PACKAGES`! +# Follows the same conventions as `REQUIRED_PACKAGES` +TEST_PACKAGES = [ + 'portpicker ~= 1.3.1', + 'scipy ~= 1.5.2', +] + DOCLINES = __doc__.split('\n') if project_name.endswith('-gpu'): @@ -111,6 +146,7 @@ CONSOLE_SCRIPTS = [ 'tflite_convert = tensorflow.lite.python.tflite_convert:main', 'toco = tensorflow.lite.python.tflite_convert:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', + 'import_pb_to_tensorboard = tensorflow.python.tools.import_pb_to_tensorboard:main', # We need to keep the TensorBoard command, even though the console script # is now declared by the tensorboard pip package. If we remove the # TensorBoard command, pip will inappropriately remove it during install, @@ -126,10 +162,6 @@ CONSOLE_SCRIPTS = [ if 'tf_nightly' in project_name: CONSOLE_SCRIPTS.remove('tensorboard = tensorboard.main:run_main') -TEST_PACKAGES = [ - 'scipy >= 0.15.1', -] - class BinaryDistribution(Distribution): @@ -297,7 +329,6 @@ setup( 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index e3b9a3557a4..64cb156089a 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -4,7 +4,6 @@ load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure") load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") -load("//third_party/mkl:build_defs.bzl", "mkl_repository") load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure") @@ -125,41 +124,15 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): armhf_repo = "../armhf_linux_toolchain", ) - mkl_repository( - name = "mkl_linux", - build_file = clean_dep("//third_party/mkl:mkl.BUILD"), - sha256 = "a936d6b277a33d2a027a024ea8e65df62bd2e162c7ca52c48486ed9d5dc27160", - strip_prefix = "mklml_lnx_2019.0.5.20190502", - urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/releases/download/v0.21/mklml_lnx_2019.0.5.20190502.tgz", - "https://github.com/intel/mkl-dnn/releases/download/v0.21/mklml_lnx_2019.0.5.20190502.tgz", - ], - ) - mkl_repository( - name = "mkl_windows", - build_file = clean_dep("//third_party/mkl:mkl.BUILD"), - sha256 = "33cc27652df3b71d7cb84b26718b5a2e8965e2c864a502347db02746d0430d57", - strip_prefix = "mklml_win_2020.0.20190813", - urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/releases/download/v0.21/mklml_win_2020.0.20190813.zip", - "https://github.com/intel/mkl-dnn/releases/download/v0.21/mklml_win_2020.0.20190813.zip", - ], - ) - mkl_repository( - name = "mkl_darwin", - build_file = clean_dep("//third_party/mkl:mkl.BUILD"), - sha256 = "2fbb71a0365d42a39ea7906568d69b1db3bfc9914fee75eedb06c5f32bf5fa68", - strip_prefix = "mklml_mac_2019.0.5.20190502", - urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/releases/download/v0.21/mklml_mac_2019.0.5.20190502.tgz", - "https://github.com/intel/mkl-dnn/releases/download/v0.21/mklml_mac_2019.0.5.20190502.tgz", - ], - ) - if path_prefix: print("path_prefix was specified to tf_workspace but is no longer used " + "and will be removed in the future.") + # To update any of the dependencies bellow: + # a) update URL and strip_prefix to the new git commit hash + # b) get the sha256 hash of the commit by running: + # curl -L | sha256sum + # and update the sha256 with the result. tf_http_archive( name = "XNNPACK", sha256 = "4b199c96fb2d551450b48eb5549843b41c023ad200aa86760a7c56d0dc0da806", @@ -190,11 +163,6 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ], ) - # Important: If you are upgrading MKL-DNN, then update the version numbers - # in third_party/mkl_dnn/mkldnn.BUILD. In addition, the new version of - # MKL-DNN might require upgrading MKL ML libraries also. If they need to be - # upgraded then update the version numbers on all three versions above - # (Linux, Mac, Windows). tf_http_archive( name = "mkl_dnn", build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"), @@ -235,11 +203,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): name = "eigen_archive", build_file = clean_dep("//third_party:eigen.BUILD"), patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"), - sha256 = "00ff67c15f8e8faf14495482e7396cc1d99cdfaaa2151f4aafef92bc754e634d", # SHARED_EIGEN_SHA - strip_prefix = "eigen-22c971a225dbb567cd1a45f6006d16c4aa618551", + sha256 = "e807a6a6f3a0e8ab10adeb59bb5a9bbb113e8e1684f9b4b32f73f58fd758b4cf", # SHARED_EIGEN_SHA + strip_prefix = "eigen-011e0db31d1bed8b7f73662be6d57d9f30fa457a", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/22c971a225dbb567cd1a45f6006d16c4aa618551/eigen-22c971a225dbb567cd1a45f6006d16c4aa618551.tar.gz", - "https://gitlab.com/libeigen/eigen/-/archive/22c971a225dbb567cd1a45f6006d16c4aa618551/eigen-22c971a225dbb567cd1a45f6006d16c4aa618551.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/011e0db31d1bed8b7f73662be6d57d9f30fa457a/eigen-011e0db31d1bed8b7f73662be6d57d9f30fa457a.tar.gz", + "https://gitlab.com/libeigen/eigen/-/archive/011e0db31d1bed8b7f73662be6d57d9f30fa457a/eigen-011e0db31d1bed8b7f73662be6d57d9f30fa457a.tar.gz", ], ) @@ -712,8 +680,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "6713332fddb796f5b14fcb6a7e5d36979676e4ab" - LLVM_SHA256 = "d83051da693c4165a8bd9a2703c806820d5146188b1036fd94300fafa09aae50" + LLVM_COMMIT = "c89447b65984c97145f63be21e42cfa98da60dd2" + LLVM_SHA256 = "b35dd27eace459897c07faa333b0cb9ddc0ef260b20582dd04b6910d548a7e08" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), @@ -730,6 +698,18 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): }, ) + # Intel openMP that is part of LLVM sources. + tf_http_archive( + name = "llvm_openmp", + build_file = clean_dep("//third_party/llvm_openmp:BUILD"), + sha256 = "d19f728c8e04fb1e94566c8d76aef50ec926cd2f95ef3bf1e0a5de4909b28b44", + strip_prefix = "openmp-10.0.1.src", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/releases/download/llvmorg-10.0.1/openmp-10.0.1.src.tar.xz", + "https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.1/openmp-10.0.1.src.tar.xz", + ], + ) + tf_http_archive( name = "lmdb", build_file = clean_dep("//third_party:lmdb.BUILD"), @@ -1154,11 +1134,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "pybind11", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.4.3.tar.gz", - "https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.6.0.tar.gz", + "https://github.com/pybind/pybind11/archive/v2.6.0.tar.gz", ], - sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d", - strip_prefix = "pybind11-2.4.3", + sha256 = "90b705137b69ee3b5fc655eaca66d0dc9862ea1759226f7ccd3098425ae69571", + strip_prefix = "pybind11-2.6.0", build_file = clean_dep("//third_party:pybind11.BUILD"), system_build_file = clean_dep("//third_party/systemlibs:pybind11.BUILD"), ) diff --git a/third_party/cpuinfo/BUILD.bazel b/third_party/cpuinfo/BUILD.bazel index 15cfcd1c4ee..9b007cc0daa 100644 --- a/third_party/cpuinfo/BUILD.bazel +++ b/third_party/cpuinfo/BUILD.bazel @@ -102,6 +102,7 @@ cc_library( ":linux_armv7a": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS, ":linux_armeabi": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS, ":linux_aarch64": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM64_SRCS, + ":linux_mips64": COMMON_SRCS + LINUX_SRCS, ":macos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, ":windows_x86_64": COMMON_SRCS + X86_SRCS + WINDOWS_X86_SRCS, ":android_armv7": COMMON_SRCS + ARM_SRCS + LINUX_SRCS + LINUX_ARM32_SRCS + ANDROID_ARM_SRCS, @@ -208,6 +209,11 @@ config_setting( values = {"cpu": "aarch64"}, ) +config_setting( + name = "linux_mips64", + values = {"cpu": "mips64"}, +) + config_setting( name = "macos_x86_64", values = { diff --git a/third_party/eigen3/gpu_packet_math.patch b/third_party/eigen3/gpu_packet_math.patch index fdc8961b93d..c0f466c24d3 100644 --- a/third_party/eigen3/gpu_packet_math.patch +++ b/third_party/eigen3/gpu_packet_math.patch @@ -23,3 +23,76 @@ diff -ru a/Eigen/src/Geometry/arch/Geometry_SSE.h b/Eigen/src/Geometry/arch/Geom return res; } }; +diff -ru a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h +--- a/Eigen/src/Core/GenericPacketMath.h ++++ b/Eigen/src/Core/GenericPacketMath.h +@@ -255,49 +255,43 @@ + return std::complex(b, b); + } + +-template +-EIGEN_DEVICE_FUNC inline Packet bitwise_helper(const Packet& a, const Packet& b, Op op) { ++/** \internal \returns the bitwise and of \a a and \a b */ ++template EIGEN_DEVICE_FUNC inline Packet ++pand(const Packet& a, const Packet& b) { + const unsigned char* a_ptr = reinterpret_cast(&a); + const unsigned char* b_ptr = reinterpret_cast(&b); + Packet c; + unsigned char* c_ptr = reinterpret_cast(&c); + for (size_t i = 0; i < sizeof(Packet); ++i) { +- *c_ptr++ = op(*a_ptr++, *b_ptr++); ++ *c_ptr++ = *a_ptr++ & *b_ptr++; + } + return c; + } + +-/** \internal \returns the bitwise and of \a a and \a b */ +-template EIGEN_DEVICE_FUNC inline Packet +-pand(const Packet& a, const Packet& b) { +-#if defined(EIGEN_HIP_DEVICE_COMPILE) +- return bitwise_helper(a ,b, std::bit_and()); +-#else +- EIGEN_USING_STD(bit_and); +- return bitwise_helper(a ,b, bit_and()); +-#endif +-} +- + /** \internal \returns the bitwise or of \a a and \a b */ + template EIGEN_DEVICE_FUNC inline Packet + por(const Packet& a, const Packet& b) { +-#if defined(EIGEN_HIP_DEVICE_COMPILE) +- return bitwise_helper(a ,b, std::bit_or()); +-#else +- EIGEN_USING_STD(bit_or); +- return bitwise_helper(a ,b, bit_or()); +-#endif ++ const unsigned char* a_ptr = reinterpret_cast(&a); ++ const unsigned char* b_ptr = reinterpret_cast(&b); ++ Packet c; ++ unsigned char* c_ptr = reinterpret_cast(&c); ++ for (size_t i = 0; i < sizeof(Packet); ++i) { ++ *c_ptr++ = *a_ptr++ | *b_ptr++; ++ } ++ return c; + } + + /** \internal \returns the bitwise xor of \a a and \a b */ + template EIGEN_DEVICE_FUNC inline Packet + pxor(const Packet& a, const Packet& b) { +-#if defined(EIGEN_HIP_DEVICE_COMPILE) +- return bitwise_helper(a ,b, std::bit_xor()); +-#else +- EIGEN_USING_STD(bit_xor); +- return bitwise_helper(a ,b, bit_xor()); +-#endif ++ const unsigned char* a_ptr = reinterpret_cast(&a); ++ const unsigned char* b_ptr = reinterpret_cast(&b); ++ Packet c; ++ unsigned char* c_ptr = reinterpret_cast(&c); ++ for (size_t i = 0; i < sizeof(Packet); ++i) { ++ *c_ptr++ = *a_ptr++ ^ *b_ptr++; ++ } ++ return c; + } + + /** \internal \returns the bitwise and of \a a and not \a b */ diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index a4a21abc367..70eacf82883 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -127,6 +127,13 @@ cc_library( linkstatic = 1, ) +cc_library( + name = "cublasLt", + srcs = ["cuda/lib/%{cublasLt_lib}"], + data = ["cuda/lib/%{cublasLt_lib}"], + linkstatic = 1, +) + cc_library( name = "cusolver", srcs = ["cuda/lib/%{cusolver_lib}"], @@ -168,6 +175,7 @@ cc_library( name = "cuda", deps = [ ":cublas", + ":cublasLt", ":cuda_headers", ":cudart", ":cudnn", diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 704003b7f63..3ba34470b93 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -551,6 +551,13 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): cuda_config.cublas_version, static = False, ), + "cublasLt": _check_cuda_lib_params( + "cublasLt", + cpu_value, + cuda_config.config["cublas_library_dir"], + cuda_config.cublas_version, + static = False, + ), "cusolver": _check_cuda_lib_params( "cusolver", cpu_value, @@ -780,6 +787,7 @@ def _create_dummy_repository(repository_ctx): "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), "%{cudart_lib}": lib_name("cudart", cpu_value), "%{cublas_lib}": lib_name("cublas", cpu_value), + "%{cublasLt_lib}": lib_name("cublasLt", cpu_value), "%{cusolver_lib}": lib_name("cusolver", cpu_value), "%{cudnn_lib}": lib_name("cudnn", cpu_value), "%{cufft_lib}": lib_name("cufft", cpu_value), @@ -811,6 +819,7 @@ filegroup(name="cudnn-include") "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value), ) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value)) @@ -1002,11 +1011,13 @@ def _create_local_cuda_repository(repository_ctx): cublas_include_path + "/cublas.h", cublas_include_path + "/cublas_v2.h", cublas_include_path + "/cublas_api.h", + cublas_include_path + "/cublasLt.h", ], outs = [ "cublas/include/cublas.h", "cublas/include/cublas_v2.h", "cublas/include/cublas_api.h", + "cublas/include/cublasLt.h", ], )) @@ -1147,6 +1158,7 @@ def _create_local_cuda_repository(repository_ctx): "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value), "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]), "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]), + "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]), "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]), "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]), "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]), diff --git a/third_party/llvm/BUILD b/third_party/llvm/BUILD index 1a5634a6285..88f1574ce9e 100644 --- a/third_party/llvm/BUILD +++ b/third_party/llvm/BUILD @@ -2,5 +2,8 @@ py_binary( name = "expand_cmake_vars", srcs = ["expand_cmake_vars.py"], srcs_version = "PY2AND3", - visibility = ["@llvm-project//:__subpackages__"], + visibility = [ + "@llvm-project//:__subpackages__", + "@llvm_openmp//:__subpackages__", + ], ) diff --git a/third_party/llvm/expand_cmake_vars.py b/third_party/llvm/expand_cmake_vars.py index ffc6a255fd1..a8a4b9673ed 100644 --- a/third_party/llvm/expand_cmake_vars.py +++ b/third_party/llvm/expand_cmake_vars.py @@ -25,6 +25,7 @@ import sys _CMAKE_DEFINE_REGEX = re.compile(r"\s*#cmakedefine\s+([A-Za-z_0-9]*)(\s.*)?$") _CMAKE_DEFINE01_REGEX = re.compile(r"\s*#cmakedefine01\s+([A-Za-z_0-9]*)") _CMAKE_VAR_REGEX = re.compile(r"\${([A-Za-z_0-9]*)}") +_CMAKE_ATVAR_REGEX = re.compile(r"@([A-Za-z_0-9]*)@") def _parse_args(argv): @@ -37,10 +38,10 @@ def _parse_args(argv): def _expand_variables(input_str, cmake_vars): - """Expands ${VARIABLE}s in 'input_str', using dictionary 'cmake_vars'. + """Expands ${VARIABLE}s and @VARIABLE@s in 'input_str', using dictionary 'cmake_vars'. Args: - input_str: the string containing ${VARIABLE} expressions to expand. + input_str: the string containing ${VARIABLE} or @VARIABLE@ expressions to expand. cmake_vars: a dictionary mapping variable names to their values. Returns: @@ -50,7 +51,7 @@ def _expand_variables(input_str, cmake_vars): if match.group(1) in cmake_vars: return cmake_vars[match.group(1)] return "" - return _CMAKE_VAR_REGEX.sub(replace, input_str) + return _CMAKE_ATVAR_REGEX.sub(replace,_CMAKE_VAR_REGEX.sub(replace, input_str)) def _expand_cmakedefines(line, cmake_vars): diff --git a/third_party/llvm_openmp/BUILD b/third_party/llvm_openmp/BUILD new file mode 100644 index 00000000000..099a84dcbaa --- /dev/null +++ b/third_party/llvm_openmp/BUILD @@ -0,0 +1,205 @@ +# Build file for OpenMP library that is part of llvm + +load( + "@org_tensorflow//third_party/llvm:llvm.bzl", + "cmake_var_string", + "expand_cmake_vars", +) +load( + "@org_tensorflow//third_party/llvm_openmp:openmp.bzl", + "dict_add", +) + +exports_files(["LICENSE.txt"]) + +genrule( + name = "kmp_i18n_id", + srcs = [ + "runtime/tools/message-converter.pl", + "runtime/src/i18n/en_US.txt", + ], + outs = ["include/kmp_i18n_id.inc"], + cmd = "perl $(location runtime/tools/message-converter.pl) --os=lin --prefix=kmp_i18n --enum=$@ $(location runtime/src/i18n/en_US.txt)", +) + +genrule( + name = "kmp_i18n_default", + srcs = [ + "runtime/tools/message-converter.pl", + "runtime/src/i18n/en_US.txt", + ], + outs = ["include/kmp_i18n_default.inc"], + cmd = "perl $(location runtime/tools/message-converter.pl) --os=lin --prefix=kmp_i18n --default=$@ $(location runtime/src/i18n/en_US.txt)", +) + +# Bazel doesn't accept .txt as an input, rename the ldscript to .inc to workaround. +genrule( + name = "ldscript", + srcs = ["runtime/src/exports_so.txt"], + outs = ["exports_so.inc"], + cmd = "cp $(location runtime/src/exports_so.txt) $@", +) + +genrule( + name = "openmp_asm", + srcs = [ + "runtime/src/z_Windows_NT-586_asm.asm", + ], + outs = [ + "z_Windows_NT-586_asm.S", + ], + cmd = "cp $(location runtime/src/z_Windows_NT-586_asm.asm) $@", + visibility = ["//visibility:public"], +) + +# Common Cmake vars to expand. +omp_vars = { + "LIBOMP_ENABLE_SHARED": 1, + "LIBOMP_LEGAL_ARCH": "Intel(R) 64", + "LIBOMP_LIB_FILE": "libiomp5", + "LIBOMP_VERSION_MAJOR": 5, + "LIBOMP_VERSION_MINOR": 0, +} + +# Linux Cmake vars to expand. +omp_vars_linux = { + "LIBOMP_USE_VERSION_SYMBOLS": 1, + "LIBOMP_HAVE_WEAK_ATTRIBUTE": 1, + "LIBOMP_USE_ADAPTIVE_LOCKS": 1, + "LIBOMP_ENABLE_ASSERTIONS": 1, +} + +# Windows Cmake vars to expand. +omp_vars_win = { + "MSVC": 1, +} + +omp_all_cmake_vars = select({ + "@org_tensorflow//tensorflow:windows": cmake_var_string( + dict_add( + omp_vars, + omp_vars_win, + ), + ), + "//conditions:default": cmake_var_string( + dict_add( + omp_vars, + omp_vars_linux, + ), + ), +}) + +expand_cmake_vars( + name = "config_kmp", + src = "runtime/src/kmp_config.h.cmake", + cmake_vars = omp_all_cmake_vars, + dst = "include/kmp_config.h", +) + +expand_cmake_vars( + name = "config_omp", + src = "runtime/src/include/omp.h.var", + cmake_vars = omp_all_cmake_vars, + dst = "include/omp.h", +) + +cppsources = [ + "runtime/src/kmp_alloc.cpp", + "runtime/src/kmp_atomic.cpp", + "runtime/src/kmp_csupport.cpp", + "runtime/src/kmp_debug.cpp", + "runtime/src/kmp_itt.cpp", + "runtime/src/kmp_environment.cpp", + "runtime/src/kmp_error.cpp", + "runtime/src/kmp_global.cpp", + "runtime/src/kmp_i18n.cpp", + "runtime/src/kmp_io.cpp", + "runtime/src/kmp_runtime.cpp", + "runtime/src/kmp_settings.cpp", + "runtime/src/kmp_str.cpp", + "runtime/src/kmp_tasking.cpp", + "runtime/src/kmp_threadprivate.cpp", + "runtime/src/kmp_utility.cpp", + "runtime/src/kmp_barrier.cpp", + "runtime/src/kmp_wait_release.cpp", + "runtime/src/kmp_affinity.cpp", + "runtime/src/kmp_dispatch.cpp", + "runtime/src/kmp_lock.cpp", + "runtime/src/kmp_sched.cpp", + "runtime/src/kmp_taskdeps.cpp", + "runtime/src/kmp_cancel.cpp", + "runtime/src/kmp_ftn_cdecl.cpp", + "runtime/src/kmp_ftn_extra.cpp", + "runtime/src/kmp_version.cpp", +] + +srcdeps = [ + ":config_kmp", + ":config_omp", + ":kmp_i18n_id", + ":kmp_i18n_default", + ":ldscript", +] + +common_includes = [ + "runtime/src/", + "include/", +] + +# TODO(Intel-tf) Replace the following 3 calls to cc_binary with cc_library. +# cc_library should be used for files that are not independently executed. Using +# cc_library results in linking errors. For e.g on Linux, the build fails +# with the following error message. +# ERROR: //tensorflow/BUILD:689:1: Linking of rule '//tensorflow:libtensorflow_framework.so.2.4.0' failed (Exit 1) +# /usr/bin/ld.gold: error: symbol GOMP_parallel_loop_nonmonotonic_guided has undefined version VERSION +# /usr/bin/ld.gold: error: symbol GOMP_parallel_start has undefined version GOMP_1.0 +# /usr/bin/ld.gold: error: symbol GOMP_cancellation_point has undefined version GOMP_4.0 +# /usr/bin/ld.gold: error: symbol omp_set_num_threads has undefined version OMP_1.0 +# ...... +# ...... + +cc_binary( + name = "libiomp5.so", + srcs = cppsources + [ + #linux specific files + "runtime/src/z_Linux_util.cpp", + "runtime/src/kmp_gsupport.cpp", + "runtime/src/z_Linux_asm.S", + ] + srcdeps, + copts = ["-Domp_EXPORTS -D_GNU_SOURCE -D_REENTRANT"], + includes = common_includes, + linkopts = ["-lpthread -ldl -Wl,--version-script=$(location :ldscript)"], + linkshared = True, + visibility = ["//visibility:public"], +) + +cc_binary( + name = "libiomp5md.dll", + srcs = cppsources + [ + #window specific files + "runtime/src/z_Windows_NT_util.cpp", + "runtime/src/z_Windows_NT-586_util.cpp", + ] + srcdeps + [":openmp_asm"], + copts = ["/Domp_EXPORTS /D_M_AMD64 /DOMPT_SUPPORT=0 /D_WINDOWS /D_WINNT /D_USRDLL"], + includes = common_includes, + linkopts = ["/MACHINE:X64"], + linkshared = True, + visibility = ["//visibility:public"], +) + +# MacOS build has not been tested, however since the MacOS build of openmp +# uses the same configuration as Linux, the following should work. +cc_binary( + name = "libiomp5.dylib", + srcs = cppsources + [ + #linux/MacOS specific files + "runtime/src/z_Linux_util.cpp", + "runtime/src/kmp_gsupport.cpp", + "runtime/src/z_Linux_asm.S", + ] + srcdeps, + copts = ["-Domp_EXPORTS -D_GNU_SOURCE -D_REENTRANT"], + includes = common_includes, + linkopts = ["-lpthread -ldl -Wl,--version-script=$(location :ldscript)"], + linkshared = True, + visibility = ["//visibility:public"], +) diff --git a/third_party/llvm_openmp/openmp.bzl b/third_party/llvm_openmp/openmp.bzl new file mode 100644 index 00000000000..9f428b5b37d --- /dev/null +++ b/third_party/llvm_openmp/openmp.bzl @@ -0,0 +1,21 @@ +"""This file contains BUILD extensions for building llvm_openmp. +TODO(Intel-tf): Delete this and reuse a similar function in third_party/llvm +after the TF 2.4 branch cut has passed. +""" + +def dict_add(*dictionaries): + """Returns a new `dict` that has all the entries of the given dictionaries. + + If the same key is present in more than one of the input dictionaries, the + last of them in the argument list overrides any earlier ones. + + Args: + *dictionaries: Zero or more dictionaries to be added. + + Returns: + A new `dict` that has all the entries of the given dictionaries. + """ + result = {} + for d in dictionaries: + result.update(d) + return result diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD index c1c2c450e34..aa65b585b85 100644 --- a/third_party/mkl/BUILD +++ b/third_party/mkl/BUILD @@ -21,6 +21,30 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "build_with_mkl_lnx_openmp", + constraint_values = [ + "@platforms//os:linux", + ], + define_values = { + "build_with_mkl": "true", + "build_with_openmp": "true", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "build_with_mkl_windows_openmp", + constraint_values = [ + "@platforms//os:windows", + ], + define_values = { + "build_with_mkl": "true", + "build_with_openmp": "true", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "build_with_mkl_aarch64", define_values = { @@ -40,18 +64,38 @@ config_setting( filegroup( name = "LICENSE", - srcs = ["MKL_LICENSE"] + select({ - "@org_tensorflow//tensorflow:linux_x86_64": [ - "@mkl_linux//:LICENSE", - ], - "@org_tensorflow//tensorflow:macos": [ - "@mkl_darwin//:LICENSE", - ], - "@org_tensorflow//tensorflow:windows": [ - "@mkl_windows//:LICENSE", - ], - "//conditions:default": [], - }), + srcs = [ + "MKL_LICENSE", + "@llvm_openmp//:LICENSE.txt", + ], + visibility = ["//visibility:public"], +) + +# TODO(Intel-tf) Remove the following 3 calls to cc_library and replace all uses +# of mkl_libs_* with @llvm_openmp//:libiomp5.* directly. + +cc_library( + name = "mkl_libs_linux", + srcs = [ + "@llvm_openmp//:libiomp5.so", + ], + visibility = ["//visibility:public"], +) + +# MacOS build configuration is provided for completness, it has not been tested +cc_library( + name = "mkl_libs_darwin", + srcs = [ + "@llvm_openmp//:libiomp5.dylib", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "mkl_libs_windows", + srcs = [ + "@llvm_openmp//:libiomp5md.dll", + ], visibility = ["//visibility:public"], ) @@ -60,16 +104,13 @@ cc_library( visibility = ["//visibility:public"], deps = select({ "@org_tensorflow//tensorflow:linux_x86_64": [ - "@mkl_linux//:mkl_headers", - "@mkl_linux//:mkl_libs_linux", + ":mkl_libs_linux", ], "@org_tensorflow//tensorflow:macos": [ - "@mkl_darwin//:mkl_headers", - "@mkl_darwin//:mkl_libs_darwin", + ":mkl_libs_darwin", ], "@org_tensorflow//tensorflow:windows": [ - "@mkl_windows//:mkl_headers", - "@mkl_windows//:mkl_libs_windows", + ":mkl_libs_windows", ], "//conditions:default": [], }), diff --git a/third_party/mkl/mkl.BUILD b/third_party/mkl/mkl.BUILD deleted file mode 100644 index 72370182c41..00000000000 --- a/third_party/mkl/mkl.BUILD +++ /dev/null @@ -1,46 +0,0 @@ -licenses(["notice"]) # 3-Clause BSD - -exports_files(["license.txt"]) - -filegroup( - name = "LICENSE", - srcs = [ - "license.txt", - ], - visibility = ["//visibility:public"], -) - -cc_library( - name = "mkl_headers", - srcs = glob(["include/*(.cc|.cpp|.cxx|.c++|.C|.c|.h|.hh|.hpp|.ipp|.hxx|.inc|.S|.s|.asm|.a|.lib|.pic.a|.lo|.lo.lib|.pic.lo|.so|.dylib|.dll|.o|.obj|.pic.o)"]), - includes = ["include"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "mkl_libs_linux", - srcs = [ - "lib/libiomp5.so", - "lib/libmklml_intel.so", - ], - visibility = ["//visibility:public"], -) - -cc_library( - name = "mkl_libs_darwin", - srcs = [ - "lib/libiomp5.dylib", - "lib/libmklml.dylib", - ], - visibility = ["//visibility:public"], -) - -cc_library( - name = "mkl_libs_windows", - srcs = [ - "lib/libiomp5md.lib", - "lib/mklml.lib", - ], - linkopts = ["/FORCE:MULTIPLE"], - visibility = ["//visibility:public"], -) diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD index 8e7a3d61564..f88d50dfc19 100644 --- a/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -1,5 +1,9 @@ exports_files(["LICENSE"]) +load( + "@org_tensorflow//tensorflow:tensorflow.bzl", + "tf_openmp_copts", +) load( "@org_tensorflow//third_party/mkl_dnn:build_defs.bzl", "if_mkl_open_source_only", @@ -14,14 +18,6 @@ load( "template_rule", ) -config_setting( - name = "clang_linux_x86_64", - values = { - "cpu": "k8", - "define": "using_clang=true", - }, -) - _DNNL_RUNTIME_OMP = { "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP", "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP", @@ -85,15 +81,7 @@ cc_library( "-fexceptions", "-UUSE_MKL", "-UUSE_CBLAS", - ] + select({ - "@org_tensorflow//tensorflow:linux_x86_64": [ - "-fopenmp", # only works with gcc - ], - # TODO(ibiryukov): enable openmp with clang by including libomp as a - # dependency. - ":clang_linux_x86_64": [], - "//conditions:default": [], - }), + ] + tf_openmp_copts(), includes = [ "include", "src", diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 31a839d81fa..d129c475a0d 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -458,6 +458,7 @@ filegroup( srcs = [ "include/mlir/Dialect/StandardOps/IR/Ops.td", "include/mlir/IR/OpAsmInterface.td", + "include/mlir/IR/SymbolInterfaces.td", "include/mlir/Interfaces/CallInterfaces.td", "include/mlir/Interfaces/ControlFlowInterfaces.td", "include/mlir/Interfaces/SideEffectInterfaces.td", @@ -3918,6 +3919,7 @@ cc_library( "include/mlir/Dialect/Linalg/EDSC/Builders.h", "include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h", "include/mlir/Dialect/Linalg/Passes.h", + "include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h", "include/mlir/Dialect/Linalg/Transforms/Hoisting.h", "include/mlir/Dialect/Linalg/Transforms/Transforms.h", "include/mlir/Dialect/Linalg/Utils/Utils.h", @@ -3945,6 +3947,7 @@ cc_library( ":Transforms", ":TransformsPassIncGen", ":VectorOps", + ":VectorToSCF", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", ], @@ -4033,7 +4036,6 @@ cc_library( ":EDSC", ":IR", ":LLVMDialect", - ":LinalgTransforms", ":Pass", ":SCFDialect", ":StandardOps", diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD index d88190ce60a..c46788047e3 100644 --- a/third_party/mlir/test.BUILD +++ b/third_party/mlir/test.BUILD @@ -17,6 +17,21 @@ cc_library( includes = ["."], ) +filegroup( + name = "TestOpTdFiles", + srcs = [ + "lib/Dialect/Test/TestOps.td", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td", + "@llvm-project//mlir:include/mlir/IR/RegionKindInterface.td", + "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", + ], +) + gentbl( name = "TestOpsIncGen", strip_include_prefix = "lib/Dialect/Test", @@ -57,14 +72,7 @@ gentbl( tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lib/Dialect/Test/TestOps.td", td_srcs = [ - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td", - "@llvm-project//mlir:include/mlir/IR/RegionKindInterface.td", - "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td", - "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", - "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td", - "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", + ":TestOpTdFiles", ], test = True, ) @@ -90,11 +98,34 @@ gentbl( test = True, ) +gentbl( + name = "TestTypeDefsIncGen", + strip_include_prefix = "lib/Dialect/Test", + tbl_outs = [ + ( + "-gen-typedef-decls", + "lib/Dialect/Test/TestTypeDefs.h.inc", + ), + ( + "-gen-typedef-defs", + "lib/Dialect/Test/TestTypeDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Test/TestTypeDefs.td", + td_srcs = [ + ":TestOpTdFiles", + ], + test = True, +) + cc_library( name = "TestDialect", srcs = [ "lib/Dialect/Test/TestDialect.cpp", "lib/Dialect/Test/TestPatterns.cpp", + "lib/Dialect/Test/TestTraits.cpp", + "lib/Dialect/Test/TestTypes.cpp", ], hdrs = [ "lib/Dialect/Test/TestDialect.h", @@ -106,6 +137,7 @@ cc_library( deps = [ ":TestInterfacesIncGen", ":TestOpsIncGen", + ":TestTypeDefsIncGen", "@llvm-project//llvm:Support", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", diff --git a/third_party/remote_config/common.bzl b/third_party/remote_config/common.bzl index 2d627e26cb8..d7e69326205 100644 --- a/third_party/remote_config/common.bzl +++ b/third_party/remote_config/common.bzl @@ -41,15 +41,24 @@ def get_python_bin(repository_ctx): python_bin = get_host_environ(repository_ctx, PYTHON_BIN_PATH) if python_bin != None: return python_bin - python_bin_path = which(repository_ctx, "python") - if python_bin_path == None: - auto_config_fail("Cannot find python in PATH, please make sure " + - "python is installed and add its directory in PATH, or --define " + - "%s='/something/else'.\nPATH=%s" % ( - PYTHON_BIN_PATH, - get_environ("PATH", ""), - )) - return python_bin_path + + # First check for an explicit "python3" + python_bin = which(repository_ctx, "python3") + if python_bin != None: + return python_bin + + # Some systems just call pythone3 "python" + python_bin = which(repository_ctx, "python") + if python_bin != None: + return python_bin + + auto_config_fail("Cannot find python in PATH, please make sure " + + "python is installed and add its directory in PATH, or --define " + + "%s='/something/else'.\nPATH=%s" % ( + PYTHON_BIN_PATH, + get_environ("PATH", ""), + )) + return python_bin # unreachable def get_bash_bin(repository_ctx): """Gets the bash bin path. diff --git a/third_party/remote_config/remote_platform_configure.bzl b/third_party/remote_config/remote_platform_configure.bzl index 386ad603950..29520396905 100644 --- a/third_party/remote_config/remote_platform_configure.bzl +++ b/third_party/remote_config/remote_platform_configure.bzl @@ -22,6 +22,8 @@ def _remote_platform_configure_impl(repository_ctx): cpu = "aarch64" elif machine_type.startswith("arm"): cpu = "arm" + elif machine_type.startswith("mips64"): + cpu = "mips64" exec_properties = repository_ctx.attr.platform_exec_properties diff --git a/third_party/systemlibs/absl_py.absl.testing.BUILD b/third_party/systemlibs/absl_py.absl.testing.BUILD index ef514a40478..ee810f8f210 100644 --- a/third_party/systemlibs/absl_py.absl.testing.BUILD +++ b/third_party/systemlibs/absl_py.absl.testing.BUILD @@ -9,3 +9,8 @@ py_library( name = "absltest", visibility = ["//visibility:public"], ) + +py_library( + name = "flagsaver", + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/protobuf.BUILD b/third_party/systemlibs/protobuf.BUILD index ccf2ab4dc7d..867ef3577c4 100644 --- a/third_party/systemlibs/protobuf.BUILD +++ b/third_party/systemlibs/protobuf.BUILD @@ -12,29 +12,48 @@ filegroup( visibility = ["//visibility:public"], ) -PROTO_FILES = [ - "google/protobuf/any.proto", - "google/protobuf/api.proto", - "google/protobuf/compiler/plugin.proto", - "google/protobuf/descriptor.proto", - "google/protobuf/duration.proto", - "google/protobuf/empty.proto", - "google/protobuf/field_mask.proto", - "google/protobuf/source_context.proto", - "google/protobuf/struct.proto", - "google/protobuf/timestamp.proto", - "google/protobuf/type.proto", - "google/protobuf/wrappers.proto", -] +# Map of all well known protos. +# name => (include path, imports) +WELL_KNOWN_PROTO_MAP = { + "any": ("google/protobuf/any.proto", []), + "api": ( + "google/protobuf/api.proto", + [ + "source_context", + "type", + ], + ), + "compiler_plugin": ( + "google/protobuf/compiler/plugin.proto", + ["descriptor"], + ), + "descriptor": ("google/protobuf/descriptor.proto", []), + "duration": ("google/protobuf/duration.proto", []), + "empty": ("google/protobuf/empty.proto", []), + "field_mask": ("google/protobuf/field_mask.proto", []), + "source_context": ("google/protobuf/source_context.proto", []), + "struct": ("google/protobuf/struct.proto", []), + "timestamp": ("google/protobuf/timestamp.proto", []), + "type": ( + "google/protobuf/type.proto", + [ + "any", + "source_context", + ], + ), + "wrappers": ("google/protobuf/wrappers.proto", []), +} + +RELATIVE_WELL_KNOWN_PROTOS = [proto[1][0] for proto in WELL_KNOWN_PROTO_MAP.items()] genrule( name = "link_proto_files", - outs = PROTO_FILES, + outs = RELATIVE_WELL_KNOWN_PROTOS, cmd = """ for i in $(OUTS); do f=$${i#$(@D)/} mkdir -p $(@D)/$${f%/*} - ln -sf $(INCLUDEDIR)/$$f $(@D)/$$f + ln -sf $(PROTOBUF_INCLUDE_PATH)/$$f $(@D)/$$f done """, ) @@ -85,74 +104,9 @@ py_library( visibility = ["//visibility:public"], ) -proto_library( - name = "any_proto", - srcs = ["google/protobuf/any.proto"], +[proto_library( + name = proto[0] + "_proto", + srcs = [proto[1][0]], visibility = ["//visibility:public"], -) - -proto_library( - name = "api_proto", - srcs = ["google/protobuf/api.proto"], - visibility = ["//visibility:public"], -) - -proto_library( - name = "compiler_plugin_proto", - srcs = ["google/protobuf/compiler/plugin.proto"], - visibility = ["//visibility:public"], -) - -proto_library( - name = "descriptor_proto", - srcs = ["google/protobuf/descriptor.proto"], - visibility = ["//visibility:public"], -) - -proto_library( - name = "duration_proto", - srcs = ["google/protobuf/duration.proto"], - visibility = ["//visibility:public"], -) - -proto_library( - name = "empty_proto", - srcs = ["google/protobuf/empty.proto"], - visibility = ["//visibility:public"], -) - -proto_library( - name = "field_mask_proto", - srcs = ["google/protobuf/field_mask.proto"], - visibility = ["//visibility:public"], -) - -proto_library( - name = "source_context_proto", - srcs = ["google/protobuf/source_context.proto"], - visibility = ["//visibility:public"], -) - -proto_library( - name = "struct_proto", - srcs = ["google/protobuf/struct.proto"], - visibility = ["//visibility:public"], -) - -proto_library( - name = "timestamp_proto", - srcs = ["google/protobuf/timestamp.proto"], - visibility = ["//visibility:public"], -) - -proto_library( - name = "type_proto", - srcs = ["google/protobuf/type.proto"], - visibility = ["//visibility:public"], -) - -proto_library( - name = "wrappers_proto", - srcs = ["google/protobuf/wrappers.proto"], - visibility = ["//visibility:public"], -) + deps = [dep + "_proto" for dep in proto[1][1]], +) for proto in WELL_KNOWN_PROTO_MAP.items()]